I am trying to fine-tune the input weights of a recurrent cell without letting backpropagation affect the previous states (a form of truncated backpropagation with n = 1). I am using tf.keras in TensorFlow with eager execution.
I cannot find a way to freeze only a specific part of the GRU cell, in particular the recurrent kernel. The recurrent kernel appears to be a TensorFlow Variable, and I cannot find a way to set its trainable attribute to False.
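The same restriction shows up on a plain variable outside of Keras, so it does not look specific to the GRU layer (a minimal repro, assuming TF 1.x with eager execution enabled as in the tutorial):

import tensorflow as tf
tf.enable_eager_execution()

v = tf.Variable(1.0)     # trainable defaults to True and is exposed as a read-only property
print(v.trainable)       # True
v.trainable = False      # AttributeError: can't set attribute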
My code is based on this text_generation tutorial (the Google Colab version, in which you can modify the build_model function and test it):
def build_model(vocab_size, embedding_dim, rnn_units, batch_size,
                freeze_embedding_layer=False, freeze_recurrent_kernel=False):
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        rnn(rnn_units, return_sequences=True,
            recurrent_initializer='glorot_uniform', stateful=True),
        tf.keras.layers.Dense(vocab_size)
    ])
    if freeze_embedding_layer:
        print("embedding type:", model.layers[0])
        model.layers[0].trainable = False
    if freeze_recurrent_kernel:
        print("rnn type:", type(model.layers[1]))
        print("rnn recurrent kernel type:", type(model.layers[1].recurrent_kernel))
        model.layers[1].recurrent_kernel.trainable = False  # <-- this is the line that fails
    return model
When I call this function with:
# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

if tf.test.is_gpu_available():
    rnn = tf.keras.layers.CuDNNGRU
else:
    import functools
    rnn = functools.partial(
        tf.keras.layers.GRU, recurrent_activation='sigmoid')

model = build_model(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
    freeze_embedding_layer=True,
    freeze_recurrent_kernel=True)
I get:
embedding type: <tensorflow.python.keras.layers.embeddings.Embedding object at 0x7f955a198d68>
rnn type: <class 'tensorflow.python.keras.layers.cudnn_recurrent.CuDNNGRU'>
rnn recurrent kernel type: <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-19-1677e05c2afc> in <module>()
3 embedding_dim=embedding_dim,
4 rnn_units=rnn_units,
----> 5 batch_size=BATCH_SIZE, freeze_embedding_layer=True, freeze_recurrent_kernel=True)
<ipython-input-18-62788170b303> in build_model(vocab_size, embedding_dim, rnn_units, batch_size, freeze_embedding_layer, freeze_recurrent_kernel)
15 print("rnn type:",type(model.layers[1]))
16 print("rnn recurrent kernel type:", type(model.layers[1].recurrent_kernel))
---> 17 model.layers[1].recurrent_kernel.trainable = False
18
19
AttributeError: can't set attribute
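The only workaround I have come up with so far is to leave the recurrent kernel trainable and simply exclude it when applying gradients in a custom training loop. A rough sketch of what I mean (train_step, inp, target, and the optimizer argument are my own hypothetical names, not part of the tutorial; assuming an optimizer that exposes apply_gradients, e.g. tf.train.AdamOptimizer):

def train_step(model, optimizer, inp, target):
    rnn_layer = model.layers[1]  # the GRU / CuDNNGRU layer built in build_model
    with tf.GradientTape() as tape:
        predictions = model(inp)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=target, logits=predictions)
    # Compute gradients for every trainable weight, then drop the pair that
    # belongs to the recurrent kernel before applying the update.
    variables = model.trainable_weights
    gradients = tape.gradient(loss, variables)
    grads_and_vars = [(g, v) for g, v in zip(gradients, variables)
                      if v is not rnn_layer.recurrent_kernel]
    optimizer.apply_gradients(grads_and_vars)
    return loss

Is there a way to mark just the recurrent kernel as non-trainable directly, or is filtering the gradients manually like this the only option?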