1

我有一个顺序 keras 模型,并且我有一个类似于以下示例的自定义层,名为“CounterLayer”。我正在使用 tensorflow 2.0(渴望执行)

class CounterLayer(tf.keras.layers.Layer):
  def __init__(self, stateful=False,**kwargs):
    self.stateful = stateful
    super(CounterLayer, self).__init__(**kwargs)


  def build(self, input_shape):
    self.count = tf.keras.backend.variable(0, name="count")
    super(CounterLayer, self).build(input_shape)

  def call(self, input):
    updates = []
    updates.append((self.count, self.count+1))
    self.add_update(updates)
    tf.print('-------------')
    tf.print(self.count)
    return input

当我运行这个例如 epoch=5 或其他东西时,self.count每次运行的值都不会更新。它始终保持不变。我从这里的https://stackoverflow.com/a/41710515/10645817得到了这个例子。我需要一些几乎与此类似的东西,但我想知道这在热切执行 tensorflow 时是否有效,或者我必须做什么才能获得预期的输出。

我一直在尝试实现这一点,但无法弄清楚。有人可以帮我吗。谢谢...

4

1 回答 1

0

是的,我的问题得到了解决。我遇到了一些更新此类变量的内置方法(即像我上面提到的情况一样,在 epoch 之间保持持久状态)。基本上我需要做的是例如:

  def build(self, input_shape):
    self.count = tf.Variable(0, dtype=tf.float32, trainable=False)
    super(CounterLayer, self).build(input_shape)

  def call(self, input):
    ............
    self.count.assign_add(1)
    ............
    return input

可以在call函数中计算更新后的值,也可以调用self.count.assign(some_updated_value). 此类操作的详细信息可在https://www.tensorflow.org/api_docs/python/tf/Variable中找到。谢谢。

于 2020-05-18T03:31:09.257 回答