You can use bfloat16 with TPUs. There are two main things to do:
- Cast the inputs to bfloat16 in your input pipeline.
- Wrap your network in a bfloat16 scope and cast its outputs to float32 for further computation.
Here is a code snippet illustrating the necessary changes:
import tensorflow as tf

def input_fn():
  def dataset_parser(self, value):
    """Parse an ImageNet record from a serialized string Tensor."""
    # Decoding of `value` into `image_bytes` and `label` is elided here.
    image = self.image_preprocessing_fn(
        image_bytes=image_bytes,
        is_training=self.is_training,
    )
    if self.use_bfloat16:
      # Cast the preprocessed image to bfloat16 before it is fed to the TPU.
      image = tf.cast(image, tf.bfloat16)
    return image, label
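For context, here is a minimal sketch of how such a parser is typically wired into a tf.data pipeline for TPUEstimator. The helper name make_input_fn and the file_pattern argument are hypothetical, not part of the original code:

def make_input_fn(file_pattern, parser):
  """Hypothetical helper: builds the tf.data pipeline that feeds the TPU."""
  def input_fn(params):
    # TPUEstimator passes the per-core batch size through `params`.
    batch_size = params['batch_size']
    dataset = tf.data.TFRecordDataset(tf.gfile.Glob(file_pattern))
    # `parser` is a bound dataset_parser, so each record is decoded,
    # preprocessed, and cast to bfloat16 on the host CPU.
    dataset = dataset.map(parser, num_parallel_calls=64)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(2)
  return input_fn

Casting on the host also halves the infeed bandwidth to the TPU compared with float32, since bfloat16 values are two bytes instead of four.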
def resnet_model_fn(features, labels, mode, params):
  """The model_fn for ResNet to be used with TPUEstimator."""

  # This nested function allows us to avoid duplicating the logic which
  # builds the network, for different values of --precision.
  def build_network():
    network = resnet_model.resnet_v1(
        resnet_depth=FLAGS.resnet_depth,
        num_classes=LABEL_CLASSES,
        data_format=FLAGS.data_format)
    return network(
        inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

  if FLAGS.precision == 'bfloat16':
    # `bfloat16_scope` comes from TF 1.x contrib:
    # from tensorflow.contrib.tpu.python.tpu import bfloat16
    with bfloat16.bfloat16_scope():
      logits = build_network()
    # Cast back to float32 so the loss and metrics are computed in
    # full precision.
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()
You can also see the second condition illustrated in this TPU model.
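For comparison, on TensorFlow 2.x the same two steps are usually expressed through the Keras mixed-precision API rather than an explicit scope. A minimal sketch, where the layer sizes are placeholders:

# Compute in bfloat16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    # Forcing the last layer to float32 mirrors the explicit
    # tf.cast(logits, tf.float32) in the TPUEstimator code above.
    tf.keras.layers.Dense(10, dtype='float32'),
])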