1

这是使用 Petastorm 训练 mnist 数据的代码。

def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):

    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
            train_readout = tf_tensors(train_reader)
            train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
            train_label = train_readout.digit
            batch_image, batch_label = tf.train.batch(
                [train_image, train_label], batch_size=batch_size
            )

不知道怎么换tf.train.batch。你能帮忙吗?

4

1 回答 1

0

您可以使用dataset.batchtf.data.Dataset支持petastorm他们tf.data.Dataset网站中提到的内容。

有关tf.data.Datasetpetastorm您一起实施的代码,请在此处获取。
有关详细信息,dataset.batch您可以在此处找到。

于 2020-11-03T11:38:27.840 回答