我正在解决一个文本分类问题。Estimator
我使用我自己的类定义了我的分类器model_fn
。我想使用 Google 的预训练word2vec
嵌入作为初始值,然后针对手头的任务进一步优化它。
我看到了这篇文章:在 TensorFlow 中使用预训练的词嵌入(word2vec 或 Glove),
它解释了如何在“原始”TensorFlow 代码中进行处理。但是,我真的很想使用这个Estimator
类。
作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种很好的方法可以传入具有初始值的相当大的文件?
假设我们有类似的东西:
def build_model_fn():
def _model_fn(features, labels, mode, params):
input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
#... what goes here to initialize W
embedded = tf.nn.embedding_lookup(W, input_layer)
...
return predictions
estimator = tf.contrib.learn.Estimator(
model_fn=build_model_fn(),
model_dir=MODEL_DIR,
params=params)
estimator.fit(input_fn=read_data, max_steps=2500)