我正在使用 CNN 的入门示例Tensorflow
并将参数更新为我自己的数据,但由于我的模型很大(244 * 244 个特征),我得到了OutOfMemory
错误。
我在 Ubuntu 14.04 上进行培训,配备 4 个 CPU 和 16Go 的 RAM。
有没有办法缩小我的数据,这样我就不会收到这个 OOM 错误?
我的代码如下所示:
# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir="path/to/model")
# Load the data
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(training_set.data)},
y=np.array(training_set.target),
num_epochs=None,
batch_size=5,
shuffle=True)
# Train the model
mnist_classifier.train(
input_fn=train_input_fn,
steps=100,
hooks=[logging_hook])