0

我正在使用 ELMO 层训练用于令牌分类的 Keras 模型。我需要保存模型以备将来使用,我已经尝试使用 model.save_weights("model_weights.h5"),但是如果我将它们加载到我构建的新模型中,然后我调用 model.predict(. ..),我得到的结果就好像模型从未被训练过一样。看起来配置没有正确保存。

我是 keras 和 tensorflow 1 的新手,我不确定这是否是这样做的方法。欢迎任何帮助!我显然在这里遗漏了一些东西,但是我找不到足够的方法来保存带有 elmo 层的模型。

我正在定义这样的模型:

def ElmoEmbedding(x):
    return elmo_model(inputs={"tokens": tf.squeeze(tf.cast(x, tf.string)),
                              "sequence_len": tf.constant(batch_size*[max_len])},
                      signature="tokens",
                      as_dict=True)["elmo"]

def build_model(max_len, n_words, n_tags): 
    word_input_layer = Input(shape=(max_len, 40, ))
    elmo_input_layer = Input(shape=(max_len,), dtype=tf.string)
    
    word_output_layer = Dense(n_tags, activation = 'softmax')(word_input_layer)
    elmo_output_layer = Lambda(ElmoEmbedding, output_shape=(1, 1024))(elmo_input_layer)
    
    output_layer = Concatenate()([word_output_layer, elmo_output_layer])
    output_layer = BatchNormalization()(output_layer)
    output_layer = Bidirectional(LSTM(units=512, return_sequences=True, recurrent_dropout=0.2, dropout=0.2))(output_layer)
    output_layer = TimeDistributed(Dense(n_tags, activation='softmax'))(output_layer)
    
    model = Model([elmo_input_layer, word_input_layer], output_layer)
    
    return model

然后我像这样进行培训:

tf.disable_eager_execution()
elmo_model = hub.Module("https://tfhub.dev/google/elmo/3", trainable=False)

sess = tf.Session()
K.set_session(sess)
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

model = build_model(max_len, n_words, n_tags)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

history = model.fit([np.array(X1_train), np.array(X2_train).reshape((len(X2_train), max_len, 40))],
                    y_train,
                    validation_data=([np.array(X1_valid), np.array(X2_valid).reshape((len(X2_valid), max_len, 40))], y_valid),
                    batch_size=batch_size, epochs=5, verbose=1)

model.save_weights("model_weights.h5")

如果我尝试在另一个会话中加载权重,如下所示,我的准确度为零:

tf.disable_eager_execution()
elmo_model = hub.Module("https://tfhub.dev/google/elmo/3", trainable=False)

sess = tf.Session()
K.set_session(sess)
sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

model = build_model(max_len, n_words, n_tags)
model.load_weights("model_weights.h5")
y_pred = model.predict([X1_test, np.array(X2_test).reshape((len(X2_test), max_len, 40))])
4

0 回答 0