我为 is_training 参数设置了一个占位符,slim.batch_norm
如下所示:
is_traing_ph = tf.placeholder(tf.bool)
output = slim.batch_norm(
input,
activation_fn=activation_fn,
is_training=is_training_ph,
updates_collections=None,
scale=scale,
scope=scope)
像这样喂它:
sess.run(train_op, feed_dict={is_training_ph:False}
当我用 True 输入 is_training_ph 时,程序正常,但是当我用 False 输入 is_traing_ph 时,程序会抛出 OOM 错误。
而且,当我不使用这样的占位符时:
output = slim.batch_norm(
input,
activation_fn=activation_fn,
is_training=True,
updates_collections=None,
scale=scale,
scope=scope)
这不是任何问题。
这是我的完整测试代码和日志跟踪: https ://gist.github.com/xxxzhi/8fc8f840a8ec07fdbae7c2fc2c77b3da
有谁知道原因?它是一个错误slim.batch_norm
吗?
GPU的显存为12G。CUDA 8,张量流1.2,张量流1.3
提前致谢。