
I am training a neural network with Keras, but as soon as I add more features I get a nan loss. I have already tried the fixes suggested in other discussions: lowering the learning rate, switching optimizers, and adding epsilon=1e-08 / clipnorm=1 to the optimizer's parameters. None of this solved it.
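
For reference, the clipping/epsilon variant I tried looks roughly like this (a minimal sketch with the values mentioned above):

from keras.optimizers import Nadam

# Lower learning rate plus gradient clipping: clipnorm rescales any
# gradient whose L2 norm exceeds 1, which normally keeps exploding
# gradients from turning the loss into nan.
optimizer = Nadam(lr=0.00001, epsilon=1e-08, clipnorm=1.0)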

The dataset is small, around 1000 samples, and I work with different feature sets. With 25 features there is no problem and performance is good, but as soon as I use all 53 features I only get nan losses. (Each feature set also works fine on its own, so I suspect the problem is not any particular feature set itself but their larger number.) The network is shown below. Any suggestions on how to fix this?
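
For what it's worth, the feature matrix itself can be checked for nan/inf values like this (a minimal sketch using the f_train / f_val arrays from the code below):

import numpy as np

# A single nan/inf entry in the wider feature matrix is enough to make
# the loss nan, even when every smaller feature subset trains fine.
for name, arr in [("f_train", f_train), ("f_val", f_val)]:
    print(name, "nan:", np.isnan(arr).any(), "inf:", np.isinf(arr).any())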

from keras.layers import (Input, Embedding, LSTM, Dense, Dropout,
                          BatchNormalization, GaussianNoise, Lambda,
                          add, multiply, concatenate)
from keras.models import Model
from keras.optimizers import Nadam
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Shared embedding layer with frozen pre-trained weights, and a shared
# LSTM encoder applied to both input sequences.
embedding_layer = Embedding(nb_words,
                            EMBEDDING_DIM,
                            weights=[embedding_matrix],
                            input_length=300,
                            trainable=False)
lstm_layer = LSTM(75, recurrent_dropout=0.2)

sequence_1_input = Input(shape=(300,), dtype="int32")
embedded_sequences_1 = embedding_layer(sequence_1_input)
x1 = lstm_layer(embedded_sequences_1)

sequence_2_input = Input(shape=(300,), dtype="int32")
embedded_sequences_2 = embedding_layer(sequence_2_input)
y1 = lstm_layer(embedded_sequences_2)

features_input = Input(shape=(f_train.shape[1],), dtype="float32")
features_dense = BatchNormalization()(features_input)
features_dense = Dense(200, activation="relu")(features_dense)
features_dense = Dropout(0.2)(features_dense)

# Combine the two LSTM encodings symmetrically: element-wise sum plus
# squared element-wise difference.
addition = add([x1, y1])
minus_y1 = Lambda(lambda x: -x)(y1)
merged = add([x1, minus_y1])
merged = multiply([merged, merged])  # (x1 - y1) ** 2
merged = concatenate([merged, addition])
merged = Dropout(0.4)(merged)

merged = concatenate([merged, features_dense])
merged = BatchNormalization()(merged)
merged = GaussianNoise(0.1)(merged)

merged = Dense(150, activation="relu")(merged)
merged = Dropout(0.2)(merged)
merged = BatchNormalization()(merged)

out = Dense(1, activation="sigmoid")(merged)
optimizer = Nadam(epsilon=1e-08, lr=0.00001)
model = Model(inputs=[sequence_1_input, sequence_2_input, features_input],
              outputs=out)
model.compile(loss="binary_crossentropy", optimizer=optimizer)
early_stopping = EarlyStopping(monitor="val_loss", patience=5)
best_model_path = "best_model" + str(model_count) + ".h5"
model_checkpoint = ModelCheckpoint(best_model_path, save_best_only=True,
                                   save_weights_only=True)

hist = model.fit([data_1_train, data_2_train, f_train], labels_train,
                 validation_data=([data_1_val, data_2_val, f_val], labels_val),
                 epochs=15, batch_size=BATCH_SIZE, shuffle=True,
                 callbacks=[early_stopping, model_checkpoint], verbose=1)
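
In case it helps with narrowing things down, Keras also ships a TerminateOnNaN callback that stops training on the first batch whose loss is nan (a minimal sketch; it would simply be appended to the callback list passed to model.fit above):

from keras.callbacks import TerminateOnNaN

# Abort on the first nan batch loss, so the failing epoch/batch is easy
# to spot in the verbose output.
callbacks = [early_stopping, model_checkpoint, TerminateOnNaN()]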