1

看看这段代码(假设 Tensorflow 2.1 或 2.2):

import tensorflow as tf

inputs = tf.keras.Input([])
var = tf.Variable(3.0)
out = var*inputs

model = tf.keras.Model(inputs=inputs, outputs=out)
print(model.variables)             # prints []

model.saved_vars = tf.Module()
model.saved_vars.var = var

print(model.saved_vars.variables)  # prints (<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>,)
print(model.variables)             # prints []

为什么是model.variables空?我希望包含var. 据我了解tf.Module,它可以用作名称范围,也可以用作tf.Variable分配给其属性的记录。既然 atf.keras.Model是 a tf.Module,它应该递归地做同样的事情,不是吗?

4

1 回答 1

0

根据文档variables是此模块及其子模块拥有的变量序列。但是当我们在 TensorFlow 2.2.0 版本中进行测试时,model.variables并没有显示子模块的变量。

变量在分配给继承自tf.Module. 正如您所做model.saved_vars = tf.Module()的那样,model.saved_vars子模块成为tf.Variables. 由于tf.keras.Model默认情况下是 a tf.Module,因此您可以直接分配 var 。但是正如您正确提到的submodule的model.variables必须记录。tf.Variablesmodel.saved_vars

代码 -尝试在内部添加另一个子模块model.saved_vars以检查是否print(model.saved_vars.variables)提供子模块变量。但还是不行。

import tensorflow as tf

inputs_1 = tf.keras.Input(shape=(3,))
var = tf.Variable(3.0)
out = var*inputs_1

model = tf.keras.Model(inputs=inputs_1, outputs=out)
print(model.variables)             # prints []

# Create Module
model.saved_vars = tf.Module()
model.saved_vars.var = var

model.var = var

print(model.saved_vars.variables)
print(model.variables)  

# Create Sub Module
model.saved_vars.saved_vars2 = tf.Module()
model.saved_vars.saved_vars2.var2 = var

print(model.saved_vars.variables)
print(model.variables)  

输出 -

[]
(<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>,)
[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>]
(<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>,)
[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.0>]

当我们得到时,将分享更多信息。

于 2020-06-09T14:37:56.503 回答