我想使用 tf.keras 构建一个自定义层。为简单起见,假设它应该在训练期间返回输入*2,在测试期间返回输入*3。这样做的正确方法是什么?
我尝试了这种方法:
class CustomLayer(Layer):
@tf.function
def call(self, inputs, training=None):
if training:
return inputs*2
else:
return inputs*3
然后我可以像这样使用这个类:
>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)
它工作正常!但是,当我在模型中使用这个类并调用它的fit()
方法时,它似乎training
没有设置为True
. 我尝试在方法的开头添加以下代码call()
,但training
始终等于 0。
if training is None:
training = K.learning_phase()
我错过了什么?
编辑
我找到了一个解决方案(请参阅我的答案),但我仍在寻找更好的解决方案@tf.function
(我更喜欢亲笔签名而不是这项smart_cond()
业务)。不幸的是,它看起来K.learning_phase()
不太好@tf.function
(我的猜测是,当call()
函数被跟踪时,学习阶段会被硬编码到图中:因为这发生在调用fit()
方法之前,所以学习阶段总是 0)。这可能是一个错误,或者在使用@tf.function
.