I am working through TensorFlow's example code to build my first CNN. The original code runs fine. However, when I try to use my own input function instead of the original tf.estimator.inputs.numpy_input_fn, I run into a problem.
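
For context, the example's original input function looks roughly like this (reproduced from memory from the MNIST tutorial, so the exact argument values may differ):

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data},
    y=train_labels,
    batch_size=100,
    num_epochs=None,
    shuffle=True)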

My code is as follows, similar to this tutorial:

import tensorflow as tf


def train_input_fn(x, y):
    # Wrap the features in a dict so the key matches what the model_fn expects.
    dataset = tf.data.Dataset.from_tensor_slices(({"x": x}, y))
    # Shuffle with a 100000-element buffer, repeat indefinitely, and batch by 100.
    dataset = dataset.shuffle(100000).repeat().batch(100)
    return dataset


my_classifier.train(
    input_fn=lambda: train_input_fn(train_data, train_labels),
    steps=200,
    hooks=[logging_hook],
)
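
In case it is useful, this is how I would pull a single batch out of the input function outside the Estimator (just a sketch using the TF 1.x one-shot iterator, not part of my actual script):

# Sanity-check the dataset pipeline by itself.
dataset = train_input_fn(train_data, train_labels)
features, labels = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    batch_features, batch_labels = sess.run([features, labels])
    print(batch_features["x"].shape, batch_labels.shape)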

However, when I run the code, TensorFlow gets stuck after printing the following to the console:

INFO:tensorflow:Using config: {'_is_chief': True, '_task_type': 'worker', '_save_checkpoints_secs': 600, '_task_id': 0, '_save_summary_steps': 10, '_keep_checkpoint_every_n_hours': 10000, '_master': '', '_evaluation_master': '', '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000217B16B64E0>, '_tf_random_seed': None, '_session_config': None, '_log_step_count_steps': 100, '_keep_checkpoint_max': 5, '_global_id_in_cluster': 0, '_model_dir': './model', '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_save_checkpoints_steps': None}
WARNING:tensorflow:Estimator's model_fn (<function cnn_model_fn at 0x000002179C3231E0>) includes params argument, but params are not passed to Estimator.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
2018-04-13 11:23:31.303380: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2018-04-13 11:23:32.071305: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1344] Found device 0 with properties: 
name: GeForce GTX 960M major: 5 minor: 0 memoryClockRate(GHz): 1.0975
pciBusID: 0000:01:00.0
totalMemory: 2.00GiB freeMemory: 1.65GiB
2018-04-13 11:23:32.071834: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1423] Adding visible gpu devices: 0
2018-04-13 11:23:33.064563: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:911] Device interconnect StreamExecutor with strength 1 edge matrix:
2018-04-13 11:23:33.064904: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:917]      0 
2018-04-13 11:23:33.065131: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:930] 0:   N 
2018-04-13 11:23:33.065460: I T:\src\github\tensorflow\tensorflow\core\common_runtime\gpu\gpu_device.cc:1041] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 1420 MB memory) -> physical GPU (device: 0, name: GeForce GTX 960M, pci bus id: 0000:01:00.0, compute capability: 5.0)
INFO:tensorflow:Restoring parameters from ./model\model.ckpt-800
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

After that, TensorFlow neither trains nor exits, which is really frustrating.

My current guess is that train_input_fn gets called once at every training step. Since that function shuffles the entire MNIST dataset, it could be computationally inefficient.

To test this idea, I tried shrinking the dataset from 60000 examples down to 100, and the code ran fine.
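
If the shuffle buffer really is the problem, one variant I could try (just a sketch; the buffer size of 1000 is an arbitrary value I picked, smaller than the 60000-example dataset) would be:

def train_input_fn_small_buffer(x, y):
    dataset = tf.data.Dataset.from_tensor_slices(({"x": x}, y))
    # Shuffle within a 1000-example buffer instead of buffering the whole dataset.
    dataset = dataset.shuffle(1000).repeat().batch(100)
    return dataset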

So what is the fix for this? How should I write a custom input function so that it does not run into this kind of problem? I would appreciate any guidance.
