Efecto secundario en tf.while_loop

Efecto secundario en tf.while_loop

Actualmente me está costando entender cómo funciona tensorflow, y siento que la interfaz de python es algo oscura.

Recientemente intenté ejecutar una declaración de impresión simple dentro de un tf.while_loop, y hay muchas cosas que no me quedan claras:

import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    tf.Print(i, [i], message='Another iteration')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

Observe que si inicializo nb_iter con

nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)

Recibí el siguiente error:

ValueError: la forma debe ser de rango 0 pero es de rango 1 para 'while/LoopCond' (op: 'LoopCond') con formas de entrada: [1].

Empeoró aún más cuando trato de usar el índice 'i' para indexar un tensor (el ejemplo no se muestra aquí), luego obtengo el siguiente error

alueError: la operación 'while/strided_slice' se ha marcado como no recuperable.

¿Puede alguien señalarme una documentación que explique cómo funciona tf.while_loop cuando se usa con tf.Variables, y si es posible usar efectos secundarios (como imprimir) dentro del bucle, así como indexar el tensor con la variable de bucle?

Gracias de antemano por su ayuda

Mostrar la mejor respuesta

En realidad, había muchas cosas mal con mi primer ejemplo:

tf.Print no se ejecuta si el operador no tiene efectos secundarios (es decir, i = tf.Print())

Si el booleano es un escalar, entonces es un tensor de rango 0, no un tensor de rango 1. ...

Este es el código que funciona:

import tensorflow as tf

#nb_iter = tf.constant(value=10)
#This solution does not work at all
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
                          initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter,10)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
v = tf.get_variable('v', shape=(10), trainable=False,
                     initializer=tf.random_uniform_initializer, dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    i = tf.Print(i, [v[i]], message='Another vector element: ')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

salida:

Another vector element: [0.203766704]
Another vector element: [0.692927241]
Another vector element: [0.732221603]
Another vector element: [0.0556482077]
Another vector element: [0.422092319]
Another vector element: [0.597698212]
Another vector element: [0.92387116]
Another vector element: [0.590101123]
Another vector element: [0.741415381]
Another vector element: [0.514917374]
res is now 10