我正在尝试将此COLA 存储库调整到我在本地文件夹中的音频数据集。我主要更改文件 contrastive.py 以使方法 _get_ssl_task_data() 适应我的新数据库。
但是,我从 model.fit (它在下面调用我的 model.train_step(data) 方法)触发了一个错误。我试图通过修改 train_step 中的数据形状来修复此错误,但没有任何成功。
我不确定这是因为形状或数据类型不兼容而导致的错误,还是因为我需要添加更多内容来调整我的图表。有人知道我的代码有什么问题吗?如果这确实是问题,我该如何替换 tf.Tensor 的使用?
这是 contrastive.py 的内容:
"""Self-supervised model for contrastive learning task."""
import os
import tensorflow as tf
import constants
import data
import network
import numpy as np
import librosa
import sys, os, glob
class ContrastiveModel:
"""Provides functionality for self-supervised constrastive learning model."""
def __init__(self,
strategy,
ssl_dataset_name,
ds_dataset_name,
model_path,
experiment_id,
batch_size,
epochs, learning_rate,
embedding_dim,
temperature,
similarity_type,
pooling_type,
noise,
steps_per_epoch = 1000):
"""Initializes a contrastive model object."""
self._strategy = strategy
self._ssl_dataset_name = ssl_dataset_name
self._ds_dataset_name = ds_dataset_name
self._model_path = model_path
self._experiment_id = experiment_id
self._batch_size = batch_size
self._epochs = epochs
self._learning_rate = learning_rate
self._temperature = temperature
self._embedding_dim = embedding_dim
self._similarity_type = similarity_type
self._pooling_type = pooling_type
self._noise = noise
self._steps_per_epoch = steps_per_epoch
self._shuffle_buffer = 1000
self._n_frames = None
self._n_bands = 64
self._n_channels = 1
self._input_shape = (-1, self._n_frames, self._n_bands, self._n_channels)
def _prepare_example(self, example):
#Creates an example (anchor-positive) for instance discrimination.
example = tf.cast(example, tf.float32) / float(tf.int16.max)
x = tf.math.l2_normalize(example, epsilon=1e-9)
waveform_a = data.extract_window(x)
mels_a = data.extract_log_mel_spectrogram(waveform_a)
frames_anchors = mels_a[Ellipsis, tf.newaxis]
waveform_p = data.extract_window(x)
waveform_p = waveform_p + (
self._noise * tf.random.normal(tf.shape(waveform_p)))
mels_p = data.extract_log_mel_spectrogram(waveform_p)
frames_positives = mels_p[Ellipsis, tf.newaxis]
return frames_anchors, frames_positives
#my own added method to create dataset
def file_load(self, wav_name, mono=False):
try:
return librosa.load(wav_name, sr=None, mono=mono)
except:
logger.error("file_broken or not exists!! : {}".format(wav_name))
#my own added method to create dataset
def make_data(self, folder_name):
all_name = glob.glob(folder_name)
files = []
for name in all_name:
files.append(self.file_load(name)[0])
#return tf.map_fn(lambda file: _prepare_example(file), np.array(files))
files = np.array(files, dtype=object)
r = tf.TensorArray(tf.float32, 0, dynamic_size=True)
for file in files:
r = r.write(r.size(), self._prepare_example(file))
#r = r.write(r.size(), file)
return r.stack()
#my adapted method to create dataset
def _get_ssl_task_data(self):
#Prepares a dataset for contrastive self-supervised task.
data_dir='path/to/my/audio/folder'
dataset = self.make_data(data_dir + "/*")
dataset = tf.data.Dataset.from_tensor_slices(dataset)
ds = dataset.repeat()
ds = ds.shuffle(self._shuffle_buffer, reshuffle_each_iteration=True)
#ds = ds.map(self._prepare_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.batch(self._batch_size, drop_remainder=True)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds
def train(self):
"""Trains a self-supervised model for contrastive learning."""
train_dataset = self._get_ssl_task_data()
train_dataset = self._strategy.experimental_distribute_dataset(train_dataset)
with self._strategy.scope():
contrastive_network = network.get_contrastive_network(
embedding_dim=self._embedding_dim,
temperature=self._temperature,
pooling_type=self._pooling_type,
similarity_type=self._similarity_type)
contrastive_network.compile(
optimizer=tf.keras.optimizers.Adam(self._learning_rate),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
ssl_model_dir = f"{self._ssl_dataset_name.value}/{self._experiment_id}/"
ckpt_path = os.path.join(self._model_path, ssl_model_dir, "ckpt_{epoch}")
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=ckpt_path, save_weights_only=True, monitor="loss")
backup_path = os.path.join(self._model_path, ssl_model_dir, "backup")
backandrestore_callback = tf.keras.callbacks.experimental.BackupAndRestore(
backup_dir=backup_path)
log_dir = os.path.join(self._model_path, "log", self._experiment_id)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
contrastive_network.fit(
train_dataset,
epochs=self._epochs,
batch_size=self._batch_size,#added
steps_per_epoch=self._steps_per_epoch,
verbose=2,
callbacks=[
model_checkpoint_callback,
backandrestore_callback,
tensorboard_callback,
])
这是 ContrastiveModel 的代码:
class ContrastiveModel(tf.keras.Model):
"""Wrapper class for custom contrastive model."""
def __init__(self, embedding_model, temperature, similarity_layer,
similarity_type):
super().__init__()
self.embedding_model = embedding_model
self._temperature = temperature
self._similarity_layer = similarity_layer
self._similarity_type = similarity_type
def train_step(self, data):
#tried to modify data shape but in vain
# shape of received data from input is [batch_size, 2, 98, 64, 1]
#data = tf.transpose(data, [1, 0, 2, 3, 4]) #failed attempt, with same error
anchors, positives = data # This one alone doesn't work either
with tf.GradientTape() as tape:
inputs = tf.concat([anchors, positives], axis=0)
embeddings = self.embedding_model(inputs, training=True)
anchor_embeddings, positive_embeddings = tf.split(embeddings, 2, axis=0)
# logits
similarities = self._similarity_layer(anchor_embeddings,
positive_embeddings)
if self._similarity_type == constants.SimilarityMeasure.DOT:
similarities /= self._temperature
sparse_labels = tf.range(tf.shape(anchors)[0])
loss = self.compiled_loss(sparse_labels, similarities)
loss += sum(self.losses)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.compiled_metrics.update_state(sparse_labels, similarities)
return {m.name: m.result() for m in self.metrics}
def get_efficient_net_encoder(input_shape, pooling):
"""Wrapper function for efficient net B0."""
efficient_net = tf.keras.applications.EfficientNetB0(
include_top=False, weights=None, input_shape=input_shape, pooling=pooling)
# To set the name `encoder` as it is used by supervised module for
# to trainable value.
return tf.keras.Model(efficient_net.inputs, efficient_net.outputs, name="encoder")
def get_contrastive_network(embedding_dim,
temperature,
pooling_type="max",
similarity_type=constants.SimilarityMeasure.DOT,
input_shape=(None, 64, 1)):
"""Creates a model for contrastive learning task."""
inputs = tf.keras.layers.Input(input_shape)
encoder = get_efficient_net_encoder(input_shape, pooling_type)
x = encoder(inputs)
outputs = tf.keras.layers.Dense(embedding_dim, activation="linear")(x)
if similarity_type == constants.SimilarityMeasure.BILINEAR:
outputs = tf.keras.layers.LayerNormalization()(outputs)
outputs = tf.keras.layers.Activation("tanh")(outputs)
embedding_model = tf.keras.Model(inputs, outputs)
if similarity_type == constants.SimilarityMeasure.BILINEAR:
embedding_dim = embedding_model.output.shape[-1]
similarity_layer = BilinearProduct(embedding_dim)
else:
similarity_layer = DotProduct()
return ContrastiveModel(embedding_model, temperature, similarity_layer,
similarity_type)
当我运行上面代码中使用 strategy = tf.distribute.MirroredStrategy() 分发的代码时,我得到了这个完整的错误:
INFO:tensorflow:Error reported to Coordinator: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Traceback (most recent call last):
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py", line 323, in run
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 667, in wrapper
return converted_call(f, args, kwargs, options=options)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 396, in converted_call
return _call_unconverted(f, args, kwargs, options)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 478, in _call_unconverted
return f(*args, **kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 820, in run_step
outputs = model.train_step(data)
File "/lfs/eq/tim/project_sept2020/cola/network.py", line 66, in train_step
anchors, positives = data
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
self._disallow_iteration()
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 498, in _disallow_iteration
self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 474, in _disallow_when_autograph_enabled
raise errors.OperatorNotAllowedInGraphError(
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
I1129 04:56:14.756884 140640742516480 coordinator.py:217] Error reported to Coordinator: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Traceback (most recent call last):
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py", line 323, in run
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 667, in wrapper
return converted_call(f, args, kwargs, options=options)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 396, in converted_call
return _call_unconverted(f, args, kwargs, options)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 478, in _call_unconverted
return f(*args, **kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 820, in run_step
outputs = model.train_step(data)
File "/lfs/eq/tim/project_sept2020/cola/network.py", line 66, in train_step
anchors, positives = data
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
self._disallow_iteration()
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 498, in _disallow_iteration
self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 474, in _disallow_when_autograph_enabled
raise errors.OperatorNotAllowedInGraphError(
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Traceback (most recent call last):
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/lfs/eq/tim/project_sept2020/cola/main.py", line 154, in <module>
app.run(main)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "/lfs/eq/tim/project_sept2020/cola/main.py", line 108, in main
model.train()
File "/lfs/eq/tim/project_sept2020/cola/contrastive.py", line 194, in train
contrastive_network.fit(
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1132, in fit
tmp_logs = self.train_function(iterator)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 784, in __call__
result = self._call(*args, **kwds)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 827, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 681, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2998, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3390, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3225, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 998, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 590, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 985, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:837 train_function *
return step_function(self, iterator)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:827 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2731 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_strategy.py:628 _call_for_each_replica
return mirrored_run.call_for_each_replica(
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py:93 call_for_each_replica
return _call_for_each_replica(strategy, fn, args, kwargs)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py:234 _call_for_each_replica
coord.join(threads)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py:389 join
six.reraise(*self._exc_info_to_raise)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/six.py:703 reraise
raise value
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/training/coordinator.py:297 stop_on_exception
yield
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py:323 run
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:820 run_step **
outputs = model.train_step(data)
/lfs/eq/tim/project_sept2020/cola/network.py:66 train_step
anchors, positives = data
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:505 __iter__
self._disallow_iteration()
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:498 _disallow_iteration
self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
/misc/home/rc/tim/anaconda3/envs/tfenv2cola/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
raise errors.OperatorNotAllowedInGraphError(
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
Epoch 1/100