我正在为 mnist 数据集使用 tesorflow_datasets 库在 GCP AI 平台上进行训练。我正在使用 tf.gan 估计器。我编写了一个使用 tfds 库读取 mnist 数据的输入管道。
import tensorflow_datasets as tfds
ds = tfds.load('mnist', split=self.split, shuffle_files=self.shuffle)
我已经在实例上使用相同的“tensorflow_datasets”库训练了我的 gan 模型,并且模型训练良好。我已将代码打包到包中,以便在 AI Platform 上运行。在 AI Platform 上训练期间,训练卡在警告中,它显示,
Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data
directory. If you'd instead prefer to read directly from our public GCS bucket.
尽管训练停滞不前,但消耗的 ML 单元不断增加。