1

我想以某种方式维护一个tf.while_loop可以支持以下功能的常量列表

  1. 我能够在索引处读取和写入(多次)常量值
  2. 我可以tf.cond通过检查它在索引处的值与某个常数来运行它

TensorArray在这里不起作用,因为它不支持重写。我还有什么其他选择?

4

1 回答 1

0

你可以定义一个正常的Tensortf.tensor_scatter_nd_update像这样更新它:

%tensorflow_version 1.x

import tensorflow as tf

data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)
tensor_size = data_tensor.shape[0]

init_state = (0, data_tensor)
condition = lambda i, _: i < tensor_size

def custom_body(i, tensor):
  special_index = 3 # index for which a value should be changed
  new_value = 8
  tensor = tf.where(tf.equal(i, special_index), 
                    tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
                    tf.tensor_scatter_nd_update(tensor, [[i]], [data[i]*2]))

  return i + 1, tensor


body = lambda i, tensor: (custom_body(i, tensor))
_, final_result = tf.while_loop(condition, body, init_state)

with tf.Session() as sess:
  final_result_values = final_result.eval()

print(final_result_values)
[2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]
于 2021-12-06T08:41:06.607 回答