1

我开始使用tf.TFRecord and tf.Example. 但是tensorflow.python.framework.errors_impl.InvalidArgumentError当我尝试从保存的 tfrecords 文件加载数据时出现错误。我一直在为这个问题寻找很多解决方案,但没有奏效。

AUTO = tf.data.experimental.AUTOTUNE


def _parse_batch(record_batch, sample_rate, duration):
    n_sample = sample_rate * duration

    feature_description = {
        'audio': tf.io.FixedLenFeature([n_sample], tf.float32),
        'label': tf.io.VarLenFeature(tf.int64)
    }

    example = tf.io.parse_example(record_batch, feature_description)

    return example['audio'], example['label']


def get_dataset_from_tfrecords(tfrecords_dir='tfrecords', split='train', batch_size=16,
                               sample_rate=44100, duration=4, n_epochs=10):
    if split not in ('train', 'validate'):
        raise ValueError("Split must be either 'train' or 'validate'")

    pattern = os.path.join(tfrecords_dir, '{}*.tfrecord'.format(split))

    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    filenames = tf.io.gfile.glob(pattern)

    # Read TFRecord files in an interleaved order
    dataset = tf.data.TFRecordDataset(filenames, compression_type='ZLIB', num_parallel_reads=AUTO)
    dataset = dataset.with_options(ignore_order)
    # Prepare batches
    dataset = dataset.batch(batch_size)

    # Parse a batch into a dataset of [audio, label] pairs
    dataset = dataset.map(lambda x: _parse_batch(x, sample_rate, duration))

    # Repeat the training data for n_epochs. Don't repeat test/validate splits.
    if split == 'train':
        dataset = dataset.repeat(n_epochs)

    return dataset.prefetch(buffer_size=AUTO)

这是完整的错误

Traceback (most recent call last):
  File "train.py", line 25, in <module>
    main()
  File "train.py", line 16, in main
    n_epochs=n_epochs)
  File "D:\Natural Language Processing\speech_to_text\utils\load_tfrecord.py", line 33, in get_dataset_from_tfrecords
    dataset = tf.data.TFRecordDataset(filenames, compression_type='ZLIB', num_parallel_reads=AUTO)
  File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 304, in __init__
    num_parallel_reads)
  File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 85, in _create_dataset_reader
    prefetch_input_elements=None)
  File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 250, in __init__
    **self._flat_structure)
  File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_experimental_dataset_ops.py", line 5977, in parallel_interleave_dataset
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: `cycle_length` must be > 0 [Op:ParallelInterleaveDataset]

谁能帮我?

4

1 回答 1

0

我在 tensorflow 2.0 上遇到了类似的问题,但是,升级到 2.1 解决了这个问题

于 2020-03-06T10:13:45.747 回答