1

我正在为 mnist 数据集使用 tesorflow_datasets 库在 GCP AI 平台上进行训练。我正在使用 tf.gan 估计器。我编写了一个使用 tfds 库读取 mnist 数据的输入管道。

import tensorflow_datasets as tfds
ds = tfds.load('mnist', split=self.split, shuffle_files=self.shuffle)

我已经在实例上使用相同的“tensorflow_datasets”库训练了我的 gan 模型,并且模型训练良好。我已将代码打包到包中,以便在 AI Platform 上运行。在 AI Platform 上训练期间,训练卡在警告中,它显示,

Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data
directory. If you'd instead prefer to read directly from our public GCS bucket.

尽管训练停滞不前,但消耗的 ML 单元不断增加。

4

1 回答 1

1

“tensorflow_datasets”库采用参数“data_dir”。如果您在 GCP 上使用此库,建议使用“data_dir”,它会在已上传 tensorflow 数据集的地方进行桶浴。

import tensorflow_datasets as tfds
ds = tfds.load('mnist', split=self.split,shuffle_files=self.shuffle, \
     data_dir='gs://tfds-data/datasets')

希望这会。您也可以检查此存储桶。它是一个包含数据集的公共存储桶。

gsutil ls gs://tfds-data/datasets/

您可以查看所有数据集

gs://tfds-data/datasets/
gs://tfds-data/datasets/downloads/
gs://tfds-data/datasets/groove/
gs://tfds-data/datasets/mnist/
gs://tfds-data/datasets/nsynth/
gs://tfds-data/datasets/wikipedia/
于 2020-01-28T10:53:13.953 回答