我开始使用tf.TFRecord and tf.Example
. 但是tensorflow.python.framework.errors_impl.InvalidArgumentError
当我尝试从保存的 tfrecords 文件加载数据时出现错误。我一直在为这个问题寻找很多解决方案,但没有奏效。
AUTO = tf.data.experimental.AUTOTUNE
def _parse_batch(record_batch, sample_rate, duration):
n_sample = sample_rate * duration
feature_description = {
'audio': tf.io.FixedLenFeature([n_sample], tf.float32),
'label': tf.io.VarLenFeature(tf.int64)
}
example = tf.io.parse_example(record_batch, feature_description)
return example['audio'], example['label']
def get_dataset_from_tfrecords(tfrecords_dir='tfrecords', split='train', batch_size=16,
sample_rate=44100, duration=4, n_epochs=10):
if split not in ('train', 'validate'):
raise ValueError("Split must be either 'train' or 'validate'")
pattern = os.path.join(tfrecords_dir, '{}*.tfrecord'.format(split))
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
filenames = tf.io.gfile.glob(pattern)
# Read TFRecord files in an interleaved order
dataset = tf.data.TFRecordDataset(filenames, compression_type='ZLIB', num_parallel_reads=AUTO)
dataset = dataset.with_options(ignore_order)
# Prepare batches
dataset = dataset.batch(batch_size)
# Parse a batch into a dataset of [audio, label] pairs
dataset = dataset.map(lambda x: _parse_batch(x, sample_rate, duration))
# Repeat the training data for n_epochs. Don't repeat test/validate splits.
if split == 'train':
dataset = dataset.repeat(n_epochs)
return dataset.prefetch(buffer_size=AUTO)
这是完整的错误
Traceback (most recent call last):
File "train.py", line 25, in <module>
main()
File "train.py", line 16, in main
n_epochs=n_epochs)
File "D:\Natural Language Processing\speech_to_text\utils\load_tfrecord.py", line 33, in get_dataset_from_tfrecords
dataset = tf.data.TFRecordDataset(filenames, compression_type='ZLIB', num_parallel_reads=AUTO)
File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 304, in __init__
num_parallel_reads)
File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 85, in _create_dataset_reader
prefetch_input_elements=None)
File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\data\ops\readers.py", line 250, in __init__
**self._flat_structure)
File "C:\Users\levan\Anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_experimental_dataset_ops.py", line 5977, in parallel_interleave_dataset
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: `cycle_length` must be > 0 [Op:ParallelInterleaveDataset]
谁能帮我?