0

当我在 tf2.x 中使用带有 keras 的 tfrecord 时,它只读取第一批进行训练,我如何读取剩余的 tfrecord 数据

def get_dataset(self,
                    tfrecord_dataset,
                    num_parallel=4,
                    batch_size=16,
                    n_epoch=1,
                    buffer_size=10000,
                    _parse_function=_default_parser):
        dataset = tfrecord_dataset \
            .map(lambda x: (_parse_function(self, x), num_parallel)) \
            .shuffle(buffer_size=buffer_size) \
            .batch(batch_size) \
            .repeat(n_epoch)
        print(dataset)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()
        return next_element[0]
4

1 回答 1

0

我不明白您为什么使用 tf.compat.v1.data.make_one_shot_iterator包装数据集。在 tf2 中,您可以model.fit(dataset)直接使用来训练模型。

于 2020-08-04T13:13:40.153 回答