Formas incompatibles cuando la salida * actions_one_hot

Formas incompatibles cuando la salida * actions_one_hot

Estoy tratando de implementar una red Deep Q que reproduzca Doom (a saber, doom)

Sin embargo, estoy atascado (desde ayer) con el problema de una codificación en caliente y sus consecuencias: de hecho, tengo 3 acciones posibles que están codificadas así

[[True, False, False], [False, True, False], [False, False, True]] tamaño = [Batch_size, 3]

Cuando codifico one_hot esta matriz de acciones, obtengo una matriz de este tamaño [BatchSize, 3, 3]

Como consecuencia cuando quiero calcular mi estimación del valor Q:

Q = tf.reduce_sum(tf.multiply(self.output, self.actions_one_hot), axis=1)

El tf.multiply(self.output, self.actions_one_hot) produce un error:

InvalidArgumentError: Incompatible shapes: [10,3] vs. [10,3,3] [[Node: DQNetwork/Mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](DQNetwork/dense/BiasAdd, DQNetwork/one_hot)]]

Entiendo que estos 2 tienen formas incompatibles para ser multiplicados pero no entiendo que debo hacer para que sean compatibles.

Para que quede más claro este es el cuaderno con cada parte explicada:

Estoy seguro de que cometí un error realmente estúpido, pero no lo veo.

¡Gracias por tu ayuda!

Mostrar la mejor respuesta

Tienes que hacer que las formas sean compatibles con tf.multiply porque la función es una multiplicación de elementos.

Sin embargo, creo que probablemente estés haciendo algo mal con el one_hot. Por lo general, una función one_hot se transformará, por ejemplo, de un número a una matriz caliente. Digamos que tiene 3 acciones posibles en su espacio de acción que son (0,1,2), la función activa traducirá eso a [[1,0,0],[0,1,0],[0,0,1]]. El problema es que está enviando los vectores one_hot a otra función one_hot. Si envía directamente las acciones, tendría la misma forma para ambos tensores.

Para resumir, estás usando la función one_hot dos veces. Si ya tiene un vector de tipo [Verdadero, Falso, Falso], ya tiene un one_hot.