3

我正在使用keras构建模型,并在tensorflow中编写优化代码和所有其他代码。当我使用像DenseConv2D这样非常简单的层时,一切都很简单。但是在我的 keras 模型中添加BatchNormalization层会使问题变得复杂。

由于BatchNormalization层在训练阶段和测试阶段的行为不同,我发现我的 feed_dict 需要K.learning_phase ( ):True。但是以下代码运行不正常。它运行没有错误,但模型的性能并没有变得更好。

import keras.backend as K
...
x_train, y_train = get_data()
sess.run(train_op, feed_dict={x:x_train, y:y_train, K.learning_phase():True})

当我尝试使用 keras fit函数训练 keras 模型时,效果很好。

我应该怎么做才能在tensorflow中使用BatchNormalization层训练keras模型?

4

1 回答 1

1

实际上我重复了这个我没见过的问题。

我在这里找到了答案,它只是将一个特殊参数传递给 BatchNormalization 层调用

于 2017-04-28T12:45:18.507 回答