最后,在 Tensorflow 2.1.0 中添加了对 TPU 的支持(截至 2020 年 1 月 8 日)。从这里的发行说明https://github.com/tensorflow/tensorflow/releases/tag/v2.1.0:
对 Keras .compile、.fit、.evaluate 和 .predict 的实验性支持适用于 Cloud TPU、Cloud TPU,适用于所有类型的 Keras 模型(顺序、功能和子类模型)。
该教程可在此处获得:https ://www.tensorflow.org/guide/tpu
为了完整起见,我将在此处添加演练:
- 转到 Google Colab 并在此处创建一个新的 Python 3 Notebook:https ://colab.research.google.com/
- 在工具栏中,单击运行时/更改运行时类型,然后在硬件加速器下选择“TPU”。
- 将以下代码复制并粘贴到笔记本中,然后单击运行单元(播放按钮)。
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import os
import tensorflow_datasets as tfds
# Distribution strategies
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# MNIST model
def create_model():
return tf.keras.Sequential(
[tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)])
# Input datasets
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255.0
return image, label
train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset
# Create and train a model
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
model = create_model()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
train_dataset, test_dataset = get_dataset()
model.fit(train_dataset,
epochs=5,
validation_data=test_dataset,steps_per_epoch=50)
请注意,当我按原样运行 tensorflow 教程中的代码时,会出现以下错误。我已经通过在 model.fit() 中添加 steps_per_epoch 参数来纠正这个问题
ValueError:无法从数据中推断出步数,请传递 steps_per_epoch 参数。