这是使用 Petastorm 训练 mnist 数据的代码。
def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
train_readout = tf_tensors(train_reader)
train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
train_label = train_readout.digit
batch_image, batch_label = tf.train.batch(
[train_image, train_label], batch_size=batch_size
)
不知道怎么换tf.train.batch
。你能帮忙吗?