以下语句,bn1 = tf.contrib.layers.batch_normalization(inputs=conv1, axis=1, training = is_training) 抛出错误:InternalError: FusedBatchNorm 的 CPU 实现目前仅支持 NHWC 张量格式。在我的 CPU 上使用 tensorflow v:1.4 但是,我确保代码使用 NHWC 格式的数据。同一段代码可以在我朋友的 CPU 上运行,唯一的区别是他使用的是 Tensorflow v.1.0,并且代码运行流畅,没有问题。
我试图查找 tensorflow 文档, https: //www.tensorflow.org/performance/performance_guide 它建议输入两个额外的参数:fused=True,data_format='NHWC'。
但是,根据 https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization ,上述两个论点没有这样的规定。事实上,代码会抛出一个错误,说 batch_normalization 收到了一个意外的参数。
任何关于问题背后的潜在原因以及我如何在不回滚我的 Tensorflow 版本(因为那将是荒谬的)的情况下解决它的任何回应都是最受欢迎的。非常感谢您的时间和精力。