4

是否可以在不费力重写其函数的情况下将 TensorFlow 中的 an 转换Estimator为 a ?TPUEstimator我有一个在EstimatorCPU 上运行良好的模型,但我不知道一种方便的方法将其转换为 aTPUEstimator而无需重写model_fnand input_fn

这需要手动完成大量工作的原因是我使用 Keras 创建模型,然后使用以下辅助函数创建Estimator

   my_keras_model.compile(
                optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                loss='categorical_crossentropy',
                metric='accuracy')
   estimator = tf.keras.estimator.model_to_estimator(keras_model=my_keras_model)

如果我能做类似estimator.to_TPU_estimator()或类似的事情会很棒——也许有人知道解决方案?

4

1 回答 1

7

不可能有这样的功能,因为model_fn两个估计器的规格不同。一些差异非常深刻,例如这个(来自TPU 教程):

在云 TPU 上训练时,您必须将优化器包装在 a 中 tf.contrib.tpu.CrossShardOptimizer,它使用 aallreduce来聚合梯度并将结果广播到每个分片(每个 TPU 核心)。

这意味着修补 keras 优化器和更新操作的内部。

model_fn推荐的方法是为 GPU 和 TPU 模型使用不同的包装器,这对您来说似乎是最快的方法。model_to_estimator在您的情况下,这意味着为 TPU 估计器重写 keras函数。


第一个也是最简单的近似是这样的:

def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est

在这里,_save_first_checkpointcall 实际上是可选的,但如果你想保留它,请从tensorflow.python.keras._impl.keras.estimator.


真正的工作发生在_create_keras_tpu_model_fn函数中,它取代了_create_keras_model_fn. 这些变化是:

  • 内部张量流优化器必须CrossShardOptimizer像前面提到的那样包装,并且

  • 内部函数必须返回TPUEstimatorSpec

可能还需要修补更多的行,但对我来说看起来还可以。完整版本如下:

from tensorflow.python.keras._impl.keras.estimator import _save_first_checkpoint, _clone_and_build_model

def model_to_estimator(keras_model=None,
                       keras_model_path=None,
                       custom_objects=None,
                       model_dir=None,
                       config=None):
  keras_weights = keras_model.get_weights()
  keras_model_fn = _create_keras_tpu_model_fn(keras_model, custom_objects)
  est = tf.contrib.tpu.TPUEstimator(keras_model_fn, model_dir=model_dir, config=config)
  _save_first_checkpoint(keras_model, est, custom_objects, keras_weights)
  return est


def _create_keras_tpu_model_fn(keras_model, custom_objects=None):

  def model_fn(features, labels, mode):
    """model_fn for keras Estimator."""
    model = _clone_and_build_model(mode, keras_model, custom_objects, features,
                                   labels)
    predictions = dict(zip(model.output_names, model.outputs))

    loss = None
    train_op = None
    eval_metric_ops = None

    # Set loss and metric only during train and evaluate.
    if mode is not tf.estimator.ModeKeys.PREDICT:
      model.optimizer.optimizer = tf.contrib.tpu.CrossShardOptimizer(model.optimizer.optimizer)

      model._make_train_function()  # pylint: disable=protected-access
      loss = model.total_loss

      if model.metrics:
        eval_metric_ops = {}
        # When each metric maps to an output
        if isinstance(model.metrics, dict):
          for i, output_name in enumerate(model.metrics.keys()):
            metric_name = model.metrics[output_name]
            if callable(metric_name):
              metric_name = metric_name.__name__
            # When some outputs use the same metric
            if list(model.metrics.values()).count(metric_name) > 1:
              metric_name += '_' + output_name
            eval_metric_ops[metric_name] = tf.metrics.mean(
                model.metrics_tensors[i - len(model.metrics)])
        else:
          for i, metric_name in enumerate(model.metrics):
            if callable(metric_name):
              metric_name = metric_name.__name__
            eval_metric_ops[metric_name] = tf.metrics.mean(
                model.metrics_tensors[i])

    if mode is tf.estimator.ModeKeys.TRAIN:
      train_op = model.train_function.updates_op

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return model_fn
于 2018-02-26T20:09:01.343 回答