当我尝试使用任一输出/损失组合时,模型运行良好,但当我尝试同时使用这两种组合时失败。因此,如果我只是不包括模型定义中的输出之一并且还消除了额外的损失,它就可以正常工作。
def prepare_dataset(ds, shuffle = False):
# ds = ds.cache()
if shuffle:
ds = ds.shuffle(buffer_size=500)
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
return ds
...model definition dataset fetching
model = tf.keras.Model(inputs, [output_sp, output_norms])
train_ds = tf.data.Dataset.from_tensor_slices((X_train_file, y_train_1, y_train_2))
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
#Here I check the first element and all parts of it have the right shape
model.compile(
optimizer='adam',
# loss=tf.keras.losses.MeanSquare(),
loss= (tf.keras.losses.MeanAbsoluteError(), tf.keras.losses.MeanSquaredError()), loss_weights=[.8, 1])
train_ds = prepare_dataset(train_ds, shuffle=True)
model.fit(
augmented,
# validation_data = val_ds,
epochs=1,
batch_size=BATCH_SIZE,
callbacks=[cp_callback]
)
Traceback (most recent call last): File "/home/scandy/Developer/RouxNN/models/normals_and_sp_transfer.py", line 245, in <module>
model.fit( File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
tmp_logs = self.train_function(iterator) File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
result = self._call(*args, **kwds) File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
return self._stateless_fn(*args, **kwds) File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
return graph_function._call_flat( File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
return self._build_call_outputs(self._inference_function.call( File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
outputs = execute.execute( File "/home/scandy/miniconda3/envs/tf-rouxnn/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found. (0) Invalid argument: Can not squeeze dim[2], expected a dimension of 1, got 3
[[{{node mean_absolute_error/weighted_loss/cond/else/_66/mean_absolute_error/weighted_loss/cond/cond/then/_387/mean_absolute_error/weighted_loss/cond/cond/Squeeze}}]]
[[mean_squared_error/cond/then/_74/mean_squared_error/cond/cond/pivot_t/_400/_117]] (1) Invalid argument: Can not squeeze dim[2], expected a dimension of 1, got 3
[[{{node mean_absolute_error/weighted_loss/cond/else/_66/mean_absolute_error/weighted_loss/cond/cond/then/_387/mean_absolute_error/weighted_loss/cond/cond/Squeeze}}]]