我在 Keras 中运行 CRNN 模型来执行一些手写识别,但在计算 CTC 损失时出现错误。
只有在使用预训练网络作为 CNN 部分时才会出现这个问题;如果改用我自己从头搭建的 CNN,则一切正常。
这是我遇到错误的模型:
def get_DensenetLSTM(training):
    """Build a CRNN whose convolutional front end is a pretrained DenseNet121.

    The network is DenseNet121 (ImageNet weights, no classification head),
    a reshape into a 32-step sequence, two stacked bidirectional LSTM stages
    and a per-step softmax over ``num_classes``. Relies on the module-level
    names ``img_w``, ``img_h``, ``num_classes``, ``max_text_len`` and
    ``ctc_lambda_func``.

    NOTE: when training, every ``input_length`` value fed to the model must
    not exceed the number of time steps that reach the CTC op. The Reshape
    below produces 32 steps and ``ctc_lambda_func`` drops the first 2, which
    leaves 30; feeding a larger value is what raises TensorFlow's
    ``sequence_length(0) <= 30`` InvalidArgumentError.

    Args:
        training: If true, return the four-input model whose output is the
            CTC loss; otherwise return the single-input prediction model.

    Returns:
        A Keras ``Model``.
    """
    input_shape = (img_w, img_h, 3)
    inputs = Input(name='the_input', shape=input_shape, dtype='float32')

    # Pretrained feature extractor, fed directly by `inputs`.
    densenet = DenseNet121(include_top=False, weights='imagenet', input_tensor=inputs)
    inner = densenet.output

    # CNN to RNN
    # NOTE(review): the target shape is hard-coded, so it only matches one
    # particular (img_w, img_h); confirm it against densenet.output's shape.
    inner = Reshape(target_shape=(32, 1536), name='reshape')(inner)  # (None, 32, 1536)
    inner = Dense(64, activation='relu', kernel_initializer='he_normal', name='dense1')(inner)  # (None, 32, 64)

    # RNN layer: first bidirectional stage, directions summed.
    lstm_1 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm1')(inner)  # (None, 32, 256)
    lstm_1b = LSTM(256, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='lstm1_b')(inner)
    lstm1_merged = add([lstm_1, lstm_1b])  # (None, 32, 256)
    lstm1_merged = BatchNormalization()(lstm1_merged)

    # Second bidirectional stage, directions concatenated.
    lstm_2 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm2')(lstm1_merged)
    lstm_2b = LSTM(256, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='lstm2_b')(lstm1_merged)
    lstm2_merged = concatenate([lstm_2, lstm_2b])  # (None, 32, 512)
    # NOTE(review): the original also built a BatchNormalization on
    # lstm2_merged but never connected its output to dense2. That unused
    # layer is dropped here, so the model is unchanged. If normalization was
    # intended, wire it in deliberately: doing so alters the architecture and
    # invalidates existing checkpoints.

    # transforms RNN output to character activations:
    inner = Dense(num_classes, kernel_initializer='he_normal', name='dense2')(lstm2_merged)  # (None, 32, num_classes)
    y_pred = Activation('softmax', name='softmax')(inner)

    labels = Input(name='the_labels', shape=[max_text_len], dtype='float32')  # (None, max_text_len)
    input_length = Input(name='input_length', shape=[1], dtype='int64')  # (None, 1)
    label_length = Input(name='label_length', shape=[1], dtype='int64')  # (None, 1)

    # Keras doesn't currently support loss funcs with extra parameters
    # so CTC loss is implemented in a lambda layer
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])  # (None, 1)

    if training:
        return Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
    else:
        return Model(inputs=[inputs], outputs=y_pred)
下面这个模型可以正常工作:
def get_Model(training):
    """Build the from-scratch CRNN: conv stack, two bidirectional LSTM stages, CTC.

    Relies on the module-level names ``img_w``, ``img_h``, ``num_classes``,
    ``max_text_len`` and ``ctc_lambda_func``.

    NOTE: the Reshape below produces 62 time steps and ``ctc_lambda_func``
    drops the first 2, so the ``input_length`` values fed during training
    must not exceed 60.

    Args:
        training: If true, return the four-input model whose output is the
            CTC loss; otherwise return the single-input prediction model.

    Returns:
        A Keras ``Model``.
    """
    input_shape = (img_w, img_h, 1)

    # Make Network
    inputs = Input(name='the_input', shape=input_shape, dtype='float32')

    # Conv block 1, then a 2x2 pool that halves both spatial axes.
    inner = Conv2D(64, (3, 3), padding='same', name='conv1', kernel_initializer='he_normal')(inputs)  # (None, img_w, img_h, 64)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    inner = MaxPooling2D(pool_size=(2, 2), name='max1')(inner)

    # Conv block 2, then another 2x2 pool.
    inner = Conv2D(128, (3, 3), padding='same', name='conv2', kernel_initializer='he_normal')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    inner = MaxPooling2D(pool_size=(2, 2), name='max2')(inner)

    # Conv blocks 3-4; the (1, 2) pool shrinks only the second spatial axis.
    inner = Conv2D(256, (3, 3), padding='same', name='conv3', kernel_initializer='he_normal')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    inner = Conv2D(256, (3, 3), padding='same', name='conv4', kernel_initializer='he_normal')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    inner = MaxPooling2D(pool_size=(1, 2), name='max3')(inner)

    # Conv blocks 5-6, again followed by a (1, 2) pool.
    inner = Conv2D(512, (3, 3), padding='same', name='conv5', kernel_initializer='he_normal')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    # NOTE(review): conv6 is the only convolution without
    # kernel_initializer='he_normal', so it uses the Keras default; kept
    # as-is to preserve behaviour.
    inner = Conv2D(512, (3, 3), padding='same', name='conv6')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)
    inner = MaxPooling2D(pool_size=(1, 2), name='max4')(inner)

    # Final 2x2 convolution. The layer name 'con7' (sic) is kept because
    # renaming it would break by-name weight loading.
    inner = Conv2D(512, (2, 2), padding='same', kernel_initializer='he_normal', name='con7')(inner)
    inner = BatchNormalization()(inner)
    inner = Activation('relu')(inner)

    # CNN to RNN
    # NOTE(review): the target shape is hard-coded, so it only matches one
    # particular (img_w, img_h).
    inner = Reshape(target_shape=(62, 2560), name='reshape')(inner)  # (None, 62, 2560)
    inner = Dense(64, activation='relu', kernel_initializer='he_normal', name='dense1')(inner)  # (None, 62, 64)

    # RNN layer: first bidirectional stage, directions summed.
    lstm_1 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm1')(inner)  # (None, 62, 256)
    lstm_1b = LSTM(256, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='lstm1_b')(inner)
    lstm1_merged = add([lstm_1, lstm_1b])  # (None, 62, 256)
    lstm1_merged = BatchNormalization()(lstm1_merged)

    # Second bidirectional stage, directions concatenated.
    lstm_2 = LSTM(256, return_sequences=True, kernel_initializer='he_normal', name='lstm2')(lstm1_merged)
    lstm_2b = LSTM(256, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='lstm2_b')(lstm1_merged)
    lstm2_merged = concatenate([lstm_2, lstm_2b])  # (None, 62, 512)
    # NOTE(review): the original also built a BatchNormalization on
    # lstm2_merged but never connected its output to dense2. That unused
    # layer is dropped here, so the model is unchanged. If normalization was
    # intended, wire it in deliberately: doing so alters the architecture and
    # invalidates existing checkpoints.

    # transforms RNN output to character activations:
    inner = Dense(num_classes, kernel_initializer='he_normal', name='dense2')(lstm2_merged)  # (None, 62, num_classes)
    y_pred = Activation('softmax', name='softmax')(inner)

    labels = Input(name='the_labels', shape=[max_text_len], dtype='float32')  # (None, max_text_len)
    input_length = Input(name='input_length', shape=[1], dtype='int64')  # (None, 1)
    label_length = Input(name='label_length', shape=[1], dtype='int64')  # (None, 1)

    # Keras doesn't currently support loss funcs with extra parameters
    # so CTC loss is implemented in a lambda layer
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])  # (None, 1)

    if training:
        return Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
    else:
        return Model(inputs=[inputs], outputs=y_pred)
这是 ctc_loss 函数:
def ctc_lambda_func(args):
    """Compute the batched CTC cost inside a Keras Lambda layer.

    Args:
        args: A 4-item sequence ``(y_pred, labels, input_length,
            label_length)``, in that order.

    Returns:
        The result of ``K.ctc_batch_cost`` on the labels and the trimmed
        predictions.

    NOTE: the first two time steps of the predictions are dropped before the
    loss is computed, so each ``input_length`` value must not exceed the
    remaining number of steps.
    """
    predictions, labels, input_length, label_length = args

    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    trimmed_pred = predictions[:, 2:, :]

    # Diagnostic output, emitted whenever this function is invoked.
    print(trimmed_pred.shape)
    print(input_length)
    print(labels.shape)

    ctc_cost = K.ctc_batch_cost(labels, trimmed_pred, input_length, label_length)
    return ctc_cost
当我运行 Densenet_LSTM 模型时,我收到以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: sequence_length(0) <= 30
[[{{node ctc/CTCLoss}} = CTCLoss[_class=["loc:@training/Adam/gradients/ctc/CTCLoss_grad/mul"], ctc_merge_repeated=true, ignore_longer_outputs_than_inputs=false, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](ctc/Log/_7309, ctc/ToInt64/_7311, ctc/ToInt32_2/_7313, ctc/ToInt32_1/_7315)]]
请帮忙。