Using the code from the Keras distributed training example; running TF 2.4.1.
Also following these docs:
https://www.tensorflow.org/guide/distributed_training
https://www.tensorflow.org/guide/distributed_training#multiworkermirroredstrategy
Training the MNIST model on a single m5.large node in AWS takes 1m 47s. With 3 machines of that type and MultiWorkerMirroredStrategy, it takes 4m 30s.
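(Both runs do 50 epochs × 70 steps = 3,500 steps, so that works out to roughly 30 ms/step on the single node versus roughly 77 ms/step across the three workers.)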
Is this simply because the training is on a "relatively small" MNIST model, and the strategy only really starts to shine on large or very large datasets?
My real input data is much larger. What is the best way to train a model on 1/2 GB of data? 1 GB? 2 GB? Non-distributed training on a single node won't work there, but the MNIST run on MultiWorkerMirroredStrategy raises the speed concern above.
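For context, with the real data I wouldn't load everything with from_tensor_slices the way mnist_dataset below does; I assume it would be a file-based tf.data pipeline, roughly like this sketch (the file pattern and feature spec are placeholders, not my actual data):

import tensorflow as tf

def file_dataset(batch_size):
    # Stream sharded TFRecord files from disk instead of holding everything in memory.
    feature_spec = {
        "image": tf.io.FixedLenFeature([28 * 28], tf.float32),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_example(serialized):
        parsed = tf.io.parse_single_example(serialized, feature_spec)
        return tf.reshape(parsed["image"], (28, 28)), parsed["label"]

    files = tf.data.Dataset.list_files("data/train-*.tfrecord")
    ds = tf.data.TFRecordDataset(files)
    ds = ds.map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.shuffle(10000).repeat().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)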
The MNIST-based test code is below.
Non-distributed:
import json
import os
import sys
import time
import numpy as np
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
if "." not in sys.path:
    sys.path.insert(0, ".")


def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # You need to convert them to float32 with values in the range [0, 1].
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset


def build_and_compile_cnn_model():
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model
start_time = time.time()
global_batch_size = 64
multi_worker_dataset = mnist_dataset(global_batch_size)
multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset, epochs=50, steps_per_epoch=70)
elapsed_time = time.time() - start_time
str_elapsed_time = time.strftime("%H : %M : %S", time.gmtime(elapsed_time))
print(">> Finished. Time elapsed: {}.".format(str_elapsed_time))
Distributed:
import json
import os
import sys
import time
import numpy as np
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
if "." not in sys.path:
    sys.path.insert(0, ".")


def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # You need to convert them to float32 with values in the range [0, 1].
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset


def build_and_compile_cnn_model():
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model
start_time = time.time()
per_worker_batch_size = 64
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])
strategy = tf.distribute.MultiWorkerMirroredStrategy()
# global_batch_size = per_worker_batch_size * num_workers
# multi_worker_dataset = mnist_dataset(global_batch_size)
# OR - turn on sharding
global_batch_size = 64
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
multi_worker_dataset = mnist_dataset(global_batch_size)
multi_worker_dataset_with_shrd = multi_worker_dataset.with_options(options)
with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset_with_shrd, epochs=50, steps_per_epoch=70)
elapsed_time = time.time() - start_time
str_elapsed_time = time.strftime("%H : %M : %S", time.gmtime(elapsed_time))
print(">> Finished. Time elapsed: {}.".format(str_elapsed_time))