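Following the answer to a question I asked earlier, I'm trying to get the custom metrics word_accuracy and char_accuracy working with a TensorFlow CRNN-CTC model implementation. It works fine in the linked notebook after running the following lines: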
```python
import tensorflow as tf
tf.config.run_functions_eagerly(True)
```
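For context on why that line helps: `tf.config.run_functions_eagerly(True)` makes every `tf.function` run as ordinary Python instead of a traced graph, and Keras typically wraps its training step in a `tf.function`. A minimal probe shows the difference. The helper `probe` is hypothetical and only records whether its body saw eager execution:

```python
import tensorflow as tf

modes = []  # one entry per time the Python body of probe actually runs


@tf.function
def probe(x):
    # Python side effect: runs while tracing (graph mode) or on every eager call
    modes.append(tf.executing_eagerly())
    return x + 1


tf.config.run_functions_eagerly(False)
probe(tf.constant(1))  # traced into a graph: the body runs once, not eagerly
tf.config.run_functions_eagerly(True)
probe(tf.constant(1))  # the body runs as plain eager Python
tf.config.run_functions_eagerly(False)  # restore the default
```

Because the eager path executes the Python body on every call, a check such as `y_true.shape[1] is not None` sees concrete shapes there, which is consistent with the metrics appearing only in eager mode.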
Here is the custom CTC layer along with the accuracy function:
```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    # Replace the -1 padding produced by CTC decoding with the placeholder label.
    # Note: .shape[0] below is a static (Python-side) shape.
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(unknown_indices.shape[0]) * unknown_placeholder, tf.int64),
    )
    if metric == 'word':
        # Fraction of samples whose whole label matches
        return tf.where(tf.reduce_all(y_true == y_pred, 1)).shape[0] / y_true.shape[0]
    if metric == 'char':
        # Fraction of individual characters that match
        return tf.where(y_true == y_pred).shape[0] / tf.reduce_prod(y_true.shape)
    return 0


class CTCLayer(Layer):
    def __init__(self, max_label_length, unknown_placeholder, **kwargs):
        super().__init__(**kwargs)
        self.max_label_length = max_label_length
        self.unknown_placeholder = unknown_placeholder

    def call(self, *args):
        y_true, y_pred = args
        batch_length = tf.cast(tf.shape(y_true)[0], dtype='int64')
        input_length = tf.cast(tf.shape(y_pred)[1], dtype='int64')
        label_length = tf.cast(tf.shape(y_true)[1], dtype='int64')
        input_length = input_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        label_length = label_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        if y_true.shape[1] is not None:  # this is to prevent an error at model creation
            # decode_batch_predictions is defined elsewhere in the notebook (not shown here)
            predictions = decode_batch_predictions(y_pred, self.max_label_length)
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'word', self.unknown_placeholder
                ),
                'word_accuracy',
            )
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'char', self.unknown_placeholder
                ),
                'char_accuracy',
            )
        self.add_loss(loss)
        return y_pred
```
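For comparison, here is a minimal sketch of the same two metrics written only with tensor ops, so nothing depends on a static `.shape[0]` being known. It assumes (this is not shown in the excerpt) that the decoded predictions are already a dense integer tensor with the same `(batch, max_label_length)` shape as `y_true`. The names `graph_safe_accuracy` and `traced_accuracy` are hypothetical:

```python
import tensorflow as tf


def graph_safe_accuracy(y_true, y_pred, unknown_placeholder):
    """Word and char accuracy using only tensor ops (no Python-side shapes)."""
    y_true = tf.cast(y_true, tf.int64)
    y_pred = tf.cast(y_pred, tf.int64)
    # Replace decoder padding (-1) with the placeholder, as calculate_accuracy does
    y_pred = tf.where(
        tf.equal(y_pred, -1), tf.constant(unknown_placeholder, tf.int64), y_pred
    )
    matches = tf.cast(tf.equal(y_true, y_pred), tf.float32)
    # A word is correct only if every character in its row matches
    word_accuracy = tf.reduce_mean(tf.reduce_min(matches, axis=1))
    # Character accuracy: matching positions over all positions
    char_accuracy = tf.reduce_mean(matches)
    return word_accuracy, char_accuracy


# Trace with fully unknown dimensions to mimic graph execution
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int64),
                              tf.TensorSpec([None, None], tf.int64)])
def traced_accuracy(y_true, y_pred):
    return graph_safe_accuracy(y_true, y_pred, unknown_placeholder=0)


y_true = tf.constant([[1, 2, 3], [4, 5, 6]], tf.int64)
y_pred = tf.constant([[1, 2, 3], [4, -1, 6]], tf.int64)
word_acc, char_acc = traced_accuracy(y_true, y_pred)  # 0.5 and 5/6
```

`reduce_min` over 0/1 match values plays the role of `tf.reduce_all(..., 1)` in the original. This only covers the accuracy computation; it does not by itself change how the metrics are registered on the layer.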
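The `if y_true.shape[1] is not None` block is meant to prevent an error at model creation, since placeholders rather than actual tensors are passed at that point. Without the if statement, this is what happens (I get the same error with or without eager execution):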
```
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    697     except Exception as e:  # pylint:disable=broad-except
    698       if hasattr(e, 'ag_error_metadata'):
--> 699         raise e.ag_error_metadata.to_exception(e)
    700       else:
    701         raise

ValueError: Exception encountered when calling layer "ctc_loss" (type CTCLayer).

in user code:

    File "<ipython-input-6-fabf4ec5a640>", line 67, in call  *
        predictions = decode_batch_predictions(y_pred, self.max_label_length)
    File "<ipython-input-6-fabf4ec5a640>", line 23, in decode_batch_predictions  *
        results = tf.keras.backend.ctc_decode(
    File "/usr/local/lib/python3.7/dist-packages/keras/backend.py", line 6436, in ctc_decode
        inputs=y_pred, sequence_length=input_length)

    ValueError: Shape must be rank 1 but is rank 0 for '{{node ctc_loss/CTCGreedyDecoder}} = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](ctc_loss/Log_1, ctc_loss/Cast_9)' with input shapes: [31,?,20], [].

Call arguments received:
  • args=('tf.Tensor(shape=(None, None), dtype=float32)', 'tf.Tensor(shape=(None, 31, 20), dtype=float32)')
```
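The last `ValueError` says the `CTCGreedyDecoder` op received a scalar (`[]`) as its `sequence_length`, while it needs a rank-1 vector with one length per sample. `decode_batch_predictions` itself is not shown, so this is only a sketch under that assumption: it calls `tf.nn.ctc_greedy_decoder` directly (the op named in the traceback) and builds the length vector from the dynamic `tf.shape`, which works even when the batch size is unknown at trace time. The function name `greedy_decode` is hypothetical:

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([None, None, None], tf.float32)])
def greedy_decode(y_pred):
    # y_pred: softmax output, batch-major (batch, time, num_classes)
    batch_size = tf.shape(y_pred)[0]
    time_steps = tf.shape(y_pred)[1]
    # sequence_length must be a rank-1 int32 vector of shape (batch,)
    sequence_length = tf.fill([batch_size], time_steps)
    # The decoder expects time-major log-probabilities: (time, batch, num_classes)
    logits = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)
    decoded, _ = tf.nn.ctc_greedy_decoder(logits, sequence_length)
    # decoded[0] is a SparseTensor; densify it with -1 padding
    return tf.sparse.to_dense(decoded[0], default_value=tf.constant(-1, tf.int64))


# Two samples, 5 time steps, 4 classes (the last class, 3, is the blank)
paths = tf.constant([[0, 0, 3, 1, 1], [2, 3, 2, 3, 3]])
decoded = greedy_decode(tf.one_hot(paths, depth=4))
```

With repeats merged and blanks dropped, the paths above decode to `[0, 1]` and `[2, 2]`.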
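Note: in graph execution the labels always have shape (None, None), so the code under the if block that adds the metrics never runs. To see the metrics working, just run the notebook I included without modification, then modify it afterwards to reproduce the error.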
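The note above comes down to the difference between static and dynamic shapes. Inside a traced function, `labels.shape[1]` is the static shape, read once in Python while tracing, whereas `tf.shape(labels)[1]` is computed for every batch. A minimal sketch, assuming labels arrive with the `(None, None)` spec described in the note (`inspect_labels` is a hypothetical name):

```python
import tensorflow as tf

static_shapes = []  # filled on the Python side, once per trace


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def inspect_labels(labels):
    # Python-side check: runs only while tracing and sees the static shape
    static_shapes.append(labels.shape[1])
    # Graph-side shape: evaluated for every call and sees the real size
    return tf.shape(labels)[1]


label_len_a = inspect_labels(tf.zeros([4, 5]))
label_len_b = inspect_labels(tf.zeros([2, 7]))  # same signature, no retrace
```

Since the Python `if` sees `None` during the single trace, the branch that adds the metrics is never built into the graph.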
Here is what you should see with eager execution enabled:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:4527: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
"Even though the `tf.config.experimental_run_functions_eagerly` "
Epoch 1/100
59/Unknown - 42s 177ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04
Epoch 00001: val_loss improved from inf to 17.36043, saving model to 1k_captcha.tf
59/59 [==============================] - 44s 213ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04 - val_loss: 17.3604 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 2/100
59/59 [==============================] - ETA: 0s - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021
Epoch 00002: val_loss improved from 17.36043 to 16.20875, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 210ms/step - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021 - val_loss: 16.2087 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 3/100
59/59 [==============================] - ETA: 0s - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110
Epoch 00003: val_loss improved from 16.20875 to 16.11712, saving model to 1k_captcha.tf
59/59 [==============================] - 12s 204ms/step - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110 - val_loss: 16.1171 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0071
Epoch 4/100
59/59 [==============================] - ETA: 0s - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184
Epoch 00004: val_loss did not improve from 16.11712
59/59 [==============================] - 12s 207ms/step - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184 - val_loss: 16.6811 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0143
Epoch 5/100
59/59 [==============================] - ETA: 0s - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225
Epoch 00005: val_loss improved from 16.11712 to 15.23923, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 214ms/step - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225 - val_loss: 15.2392 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0268
Epoch 6/100
59/59 [==============================] - ETA: 0s - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258
Epoch 00006: val_loss did not improve from 15.23923
59/59 [==============================] - 12s 207ms/step - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258 - val_loss: 18.6373 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0089
Epoch 7/100
59/59 [==============================] - ETA: 0s - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335
Epoch 00007: val_loss improved from 15.23923 to 14.37547, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 215ms/step - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335 - val_loss: 14.3755 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 8/100
59/59 [==============================] - ETA: 0s - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422
Epoch 00008: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 208ms/step - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422 - val_loss: 14.4376 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 9/100
59/59 [==============================] - ETA: 0s - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780
Epoch 00009: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 211ms/step - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780 - val_loss: 14.8398 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0500
Epoch 10/100
59/59 [==============================] - ETA: 0s - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460
Epoch 00010: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 215ms/step - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460 - val_loss: 14.4219 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1054
Epoch 11/100
59/59 [==============================] - ETA: 0s - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004
Epoch 00011: val_loss improved from 14.37547 to 10.11944, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 212ms/step - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004 - val_loss: 10.1194 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1750
Epoch 12/100
59/59 [==============================] - ETA: 0s - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388
Epoch 00012: val_loss did not improve from 10.11944
59/59 [==============================] - 13s 216ms/step - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388 - val_loss: 10.3900 - val_word_accuracy: 0.0089 - val_char_accuracy: 0.1714
Epoch 13/100
59/59 [==============================] - ETA: 0s - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047
Epoch 00013: val_loss improved from 10.11944 to 8.38430, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 215ms/step - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047 - val_loss: 8.3843 - val_word_accuracy: 0.0179 - val_char_accuracy: 0.2714
Epoch 14/100
59/59 [==============================] - ETA: 0s - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519
Epoch 00014: val_loss did not improve from 8.38430
59/59 [==============================] - 13s 217ms/step - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519 - val_loss: 9.5689 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.2571
Epoch 15/100
59/59 [==============================] - ETA: 0s - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271
Epoch 00015: val_loss improved from 8.38430 to 6.74445, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 214ms/step - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271 - val_loss: 6.7445 - val_word_accuracy: 0.0804 - val_char_accuracy: 0.3482
Epoch 16/100
59/59 [==============================] - ETA: 0s - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799
Epoch 00016: val_loss improved from 6.74445 to 5.40682, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 222ms/step - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799 - val_loss: 5.4068 - val_word_accuracy: 0.1161 - val_char_accuracy: 0.4446
Epoch 17/100
59/59 [==============================] - ETA: 0s - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206
Epoch 00017: val_loss improved from 5.40682 to 4.76755, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 220ms/step - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206 - val_loss: 4.7676 - val_word_accuracy: 0.1964 - val_char_accuracy: 0.4929
Epoch 18/100
59/59 [==============================] - ETA: 0s - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712
Epoch 00018: val_loss improved from 4.76755 to 4.45828, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 221ms/step - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712 - val_loss: 4.4583 - val_word_accuracy: 0.2768 - val_char_accuracy: 0.5375
Epoch 19/100
59/59 [==============================] - ETA: 0s - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267
Epoch 00019: val_loss improved from 4.45828 to 4.13174, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 222ms/step - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267 - val_loss: 4.1317 - val_word_accuracy: 0.2054 - val_char_accuracy: 0.5143
Epoch 20/100
59/59 [==============================] - ETA: 0s - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979
Epoch 00020: val_loss improved from 4.13174 to 3.35257, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 223ms/step - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979 - val_loss: 3.3526 - val_word_accuracy: 0.3482 - val_char_accuracy: 0.5518
Epoch 21/100
59/59 [==============================] - ETA: 0s - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284
Epoch 00021: val_loss did not improve from 3.35257
59/59 [==============================] - 13s 223ms/step - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284 - val_loss: 3.5486 - val_word_accuracy: 0.3304 - val_char_accuracy: 0.5500
Epoch 22/100
59/59 [==============================] - ETA: 0s - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021
Epoch 00022: val_loss improved from 3.35257 to 2.97987, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 229ms/step - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021 - val_loss: 2.9799 - val_word_accuracy: 0.3750 - val_char_accuracy: 0.6679
Epoch 23/100
59/59 [==============================] - ETA: 0s - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417
Epoch 00023: val_loss did not improve from 2.97987
59/59 [==============================] - 13s 224ms/step - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417 - val_loss: 5.2543 - val_word_accuracy: 0.1429 - val_char_accuracy: 0.4768
Epoch 24/100
59/59 [==============================] - ETA: 0s - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576
Epoch 00024: val_loss improved from 2.97987 to 2.72415, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 226ms/step - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576 - val_loss: 2.7242 - val_word_accuracy: 0.4911 - val_char_accuracy: 0.7250
Epoch 25/100
59/59 [==============================] - ETA: 0s - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837
Epoch 00025: val_loss improved from 2.72415 to 2.47315, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 223ms/step - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837 - val_loss: 2.4731 - val_word_accuracy: 0.4554 - val_char_accuracy: 0.6964
Epoch 26/100
59/59 [==============================] - ETA: 0s - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195
Epoch 00026: val_loss improved from 2.47315 to 2.10521, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 227ms/step - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195 - val_loss: 2.1052 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.6929
Epoch 27/100
59/59 [==============================] - ETA: 0s - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528
Epoch 00027: val_loss did not improve from 2.10521
59/59 [==============================] - 14s 226ms/step - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528 - val_loss: 2.5292 - val_word_accuracy: 0.4375 - val_char_accuracy: 0.7054
Epoch 28/100
59/59 [==============================] - ETA: 0s - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500
Epoch 00028: val_loss did not improve from 2.10521
59/59 [==============================] - 14s 224ms/step - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500 - val_loss: 2.1713 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.7268
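To reproduce the problem, restart the runtime if you have already run the notebook, then run it without eager execution: the metrics never show up. To reproduce the error, comment out the line `if y_true.shape[1] is not None` and merge the if block into the rest of the code. What do I need to change in the provided notebook so that the metrics work as demonstrated, without using eager execution?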