2

我的模型太大,无法使用普通 v2 TPU 设备获得 >64 的批次。在故障排除站点上提到即将推出的 tensorflow 版本将支持 bfloat16。新支持的 tf 版本 1.9-1.12 现在可以使用 bfloat16 吗?如果可以,我可以使用一组有限的优化器吗?我没有找到任何进一步的文档,但在 tensor2tensor 模型中看到了 bfloat16 的用法,所以我想一定有办法。

此外,我读到TPU v3 也支持更大的模型,但模型需要的更改很少,但我没有找到任何需要更改的文档。

我已经在使用Adafactor并尝试减少我的层数,如果您有任何进一步的减少技巧,那也很棒。我正在使用图片矩阵和词向量(截至目前为 float32)作为输入。

4

1 回答 1

2

您可以bfloat16与 TPU 一起使用。主要有两件事要做:

  1. 在输入管道中将输入转换为 bfloat16
  2. 将您的网络包围在 bfloat16 范围内,并将输出转换为 F32 以供进一步计算。

这是一个说明必要更改的代码片段:

def input_fn():

  def dataset_parser(self, value):
    """Parse an ImageNet record from a serialized string Tensor."""
    image = self.image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=self.is_training,
    )

    if self.use_bfloat16:
      image = tf.cast(image, tf.bfloat16)

    return image, label


def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    with bfloat16.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

您还可以看到此 TPU 模型中说明的第二个条件。

于 2018-12-18T19:47:54.060 回答