5

我正在解决一个文本分类问题。Estimator我使用我自己的类定义了我的分类器model_fn。我想使用 Google 的预训练word2vec嵌入作为初始值,然后针对手头的任务进一步优化它。

我看到了这篇文章:在 TensorFlow 中使用预训练的词嵌入(word2vec 或 Glove),
它解释了如何在“原始”TensorFlow 代码中进行处理。但是,我真的很想使用这个Estimator类。

作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种很好的方法可以传入具有初始值的相当大的文件?

假设我们有类似的东西:

def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
        #... what goes here to initialize W

        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
4

1 回答 1

9

嵌入通常足够大,唯一可行的方法是使用它们来初始化tf.Variable图中的 a。这将允许您利用分布式等的参数服务器。

为此(以及其他),我建议您使用新的“核心”估算器,tf.estimator.Estimator因为这会使事情变得更容易。

根据您提供的链接中的答案,并且知道我们想要一个变量而不是常量,我们可以采取以下方法:

(2) 使用 feed dict 初始化变量,或 (3) 从检查点加载变量


我将首先介绍选项(3),因为它更容易,更好:

在您的中,只需使用调用返回model_fn的变量初始化一个变量。这需要:Tensortf.contrib.framework.load_variable

  1. 您的嵌入具有有效的 TF 检查点
  2. 您知道检查点内嵌入变量的完全限定名称。

代码非常简单:

def model_fn(mode, features, labels, hparams):
  embeddings = tf.Variable(tf.contrib.framework.load_variable(
      'gs://my-bucket/word2vec_checkpoints/',
      'a/fully/qualified/scope/embeddings'
  ))
  ....
  return tf.estimator.EstimatorSpec(...)

但是,如果您的嵌入不是由另一个 TF 模型生成的,则此方法对您不起作用,因此选项 (2)。


对于 (2),我们需要使用tf.train.Scaffoldwhich 本质上是一个配置对象,它包含启动 a 的所有选项tf.Session(由于很多原因,估计器有意隐藏)。

您可以Scaffoldtf.train.EstimatorSpec您的model_fn.

我们在 model_fn 中创建一个占位符,并将其作为嵌入变量的初始化操作,然后init_feed_dict通过Scaffold. 例如

def model_fn(mode, features, labels, hparams):
  embed_ph = tf.placeholder(
      shape=[hparams.vocab_size, hparams.embedding_size], 
      dtype=tf.float32)
  embeddings = tf.Variable(embed_ph)
  # Define your model
  return tf.estimator.EstimatorSpec(
      ..., # normal EstimatorSpec args
      scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
  )

这里发生的是init_feed_dict将在运行时填充embed_ph占位符的值,然后允许embeddings.initialization_op(占位符的分配)运行。


于 2017-06-21T21:46:12.347 回答