我试图通过使用编码器创建一个连体网络来微调通用句子编码器。我想在训练期间训练 tensorflow_hub 通用编码器模块的权重,但我不确定如何使用估计器来做到这一点。
我的问题是,如果我在下面使用两个 hub.text_embedding_column 进行设置,它将训练两个单独的网络,而不是像训练它们是连体网络一样训练它们。如果不共享权重,我将如何更改它以便共享和训练权重。如果有帮助,我可以从本地机器加载模块。
def train_and_evaluate_with_module(hub_module, train_module=False):
embedded_text_feature_column1 = hub.text_embedding_column(
key="sentence1", module_spec=hub_module, trainable=train_module)
embedded_text_feature_column2 = hub.text_embedding_column(
key="sentence2", module_spec=hub_module, trainable=train_module)
estimator = tf.estimator.DNNClassifier(
hidden_units=[500, 100],
feature_columns=[embedded_text_feature_column1,embedded_text_feature_column2],
n_classes=2,
optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))
estimator.train(input_fn=train_input_fn, steps=1000)
train_eval_result = estimator.evaluate(input_fn=predict_train_input_fn)
test_eval_result = estimator.evaluate(input_fn=predict_test_input_fn)
training_set_accuracy = train_eval_result["accuracy"]
test_set_accuracy = test_eval_result["accuracy"]
return {
"Training accuracy": training_set_accuracy,
"Test accuracy": test_set_accuracy
}
results = train_and_evaluate_with_module("https://tfhub.dev/google/universal-sentence-encoder-large/3", True)