我想以某种方式维护一个tf.while_loop
可以支持以下功能的常量列表
- 我能够在索引处读取和写入(多次)常量值
- 我可以
tf.cond
通过检查它在索引处的值与某个常数来运行它
TensorArray
在这里不起作用,因为它不支持重写。我还有什么其他选择?
我想以某种方式维护一个tf.while_loop
可以支持以下功能的常量列表
tf.cond
通过检查它在索引处的值与某个常数来运行它TensorArray
在这里不起作用,因为它不支持重写。我还有什么其他选择?
你可以定义一个正常的Tensor
并tf.tensor_scatter_nd_update
像这样更新它:
%tensorflow_version 1.x
import tensorflow as tf
data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)
tensor_size = data_tensor.shape[0]
init_state = (0, data_tensor)
condition = lambda i, _: i < tensor_size
def custom_body(i, tensor):
special_index = 3 # index for which a value should be changed
new_value = 8
tensor = tf.where(tf.equal(i, special_index),
tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
tf.tensor_scatter_nd_update(tensor, [[i]], [data[i]*2]))
return i + 1, tensor
body = lambda i, tensor: (custom_body(i, tensor))
_, final_result = tf.while_loop(condition, body, init_state)
with tf.Session() as sess:
final_result_values = final_result.eval()
print(final_result_values)
[2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]