当我在 tf2.x 中使用带有 keras 的 tfrecord 时,它只读取第一批进行训练,我如何读取剩余的 tfrecord 数据
def get_dataset(self,
tfrecord_dataset,
num_parallel=4,
batch_size=16,
n_epoch=1,
buffer_size=10000,
_parse_function=_default_parser):
dataset = tfrecord_dataset \
.map(lambda x: (_parse_function(self, x), num_parallel)) \
.shuffle(buffer_size=buffer_size) \
.batch(batch_size) \
.repeat(n_epoch)
print(dataset)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
return next_element[0]