我有两个训练有素的模型 (model_A
和model_B
),它们都有 dropout 层。我已经冻结model_A
并将model_B
它们与一个新的密集层合并以获得model_AB
(但我没有删除model_A
's和model_B
's dropout层)。model_AB
的权重将是不可训练的,除了添加的密集层。
现在我的问题是:当我训练时,辍学层是否处于活动状态(即丢弃神经元model_A
)?model_B
model_AB
我有两个训练有素的模型 (model_A
和model_B
),它们都有 dropout 层。我已经冻结model_A
并将model_B
它们与一个新的密集层合并以获得model_AB
(但我没有删除model_A
's和model_B
's dropout层)。model_AB
的权重将是不可训练的,除了添加的密集层。
现在我的问题是:当我训练时,辍学层是否处于活动状态(即丢弃神经元model_A
)?model_B
model_AB
简短回答:即使您将它们的属性设置为,dropout 层也会在训练期间继续丢弃神经元。trainable
False
长答案: Keras 中有两个不同的概念:
更新层的权重和状态:这是使用trainable
该层的属性控制的,即如果您设置,layer.trainable = False
则不会更新层的权重和内部状态。
层在训练和测试阶段的行为:如您所知,层,如 dropout,在训练和测试阶段可能具有不同的行为。Keras 中的学习阶段使用keras.backend.set_learning_phase()
. 例如,当你调用model.fit(...)
学习阶段时自动设置为1(即训练),而当你使用时model.predict(...)
它会自动设置为0(即测试)。此外,请注意,学习阶段 1(即训练)并不一定意味着更新层的权重/状态。您可以在学习阶段 1(即训练阶段)运行您的模型,但不会更新权重;只是层将切换到它们的训练行为(有关更多信息,请参阅此答案)。此外,还有另一种方法可以通过传递为每个单独的层设置学习阶段training=True
在张量上调用层时的参数(有关更多信息,请参见此答案)。
因此,根据以上几点,当您设置trainable=False
一个 dropout 层并在训练模式下使用它(例如,通过调用model.fit(...)
或手动将学习阶段设置为像下面的示例那样进行训练),神经元仍然会被 dropout 层丢弃。
这是一个可重现的示例,说明了这一点:
from keras import layers
from keras import models
from keras import backend as K
import numpy as np
inp = layers.Input(shape=(10,))
out = layers.Dropout(0.5)(inp)
model = models.Model(inp, out)
model.layers[-1].trainable = False # set dropout layer as non-trainable
model.compile(optimizer='adam', loss='mse') # IMPORTANT: we must always compile model after changing `trainable` attribute
# create a custom backend function so that we can control the learning phase
func = K.function(model.inputs + [K.learning_phase()], model.outputs)
x = np.ones((1,10))
# learning phase = 1, i.e. training mode
print(func([x, 1]))
# the output will be:
[array([[2., 2., 2., 0., 0., 2., 2., 2., 0., 0.]], dtype=float32)]
# as you can see some of the neurons have been dropped
# now set learning phase = 0, i.e test mode
print(func([x, 0]))
# the output will be:
[array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)]
# unsurprisingly, no neurons have been dropped in test phase