Google 的 BERT repo 中的README说,即使是长度为 512 的单个句子也不能放在 BERT-Large 模型的 12 GB Titan X 中。
但在 BERT 论文中,它说使用 64 个 TPU 芯片来训练 BERT-Large,最大长度为 512,批量大小为 256。它们如何将大于 256 倍的批量放入仅增加 171 倍的内存中?
从另一个角度来看,我们可以在每个样本的内存使用情况下比较这两种配置:
- TPU:假设TPUv3用于预训练,总TPU内存为32GB/芯片*64芯片=2048GB。根据论文,256 的批量大小和最大长度 512 在此配置中运行良好,这意味着8 GB 内存能够容纳单个样本。此外,如果使用 GPUv2,每个样本的内存使用量将减少到仅 4 GB。
- GPU:12 GB Titan X 甚至无法容纳长度为 512 的单个样本。
为什么 GPU 上的内存消耗要大得多?这是否意味着 TPU 上的内存消耗比 GPU 上的优化方式更好?