我有一项任务必须使用 Adam 优化器快速运行 N 个并行简单优化。我一直在使用 tensorflow 1.x 进行此操作,但尝试将所有内容更新到 2.x 或现代 pytorch 会导致行为慢得多。
我在 tensorflow 1.x 2.x 和 pytorch 中构建了最小的(至少在我的能力范围内)示例,这些示例都在 cpu 上。 https://gist.github.com/gftabor/abeb108fc9aa8b1c799bfc63287c2e5f
如您所见,tensorflow 2.x 的耗时大约是 tensorflow 1.x 的 8 倍。Pytorch 类似于 tf 2.x。
我认为问题是动态图执行,所以这是我挖得最多的兔子洞,但我关心的只是性能。
兔子洞
我一直在尝试通过使用 tf.function 和 torch.jit.script 功能静态构建图表来匹配 Tensorflow 1.x 的性能。它们似乎都没有给我几乎 Tensorflow 1.x 的表现力,而且它们似乎都没有与 Tensorflow 1.x 的性能相匹配,在我的问题上,我可以构建到静态图中的部分只是前向 + 损失函数而不是一个完整的亚当优化步骤。在使用 Tensorflow 1.x 时,我可以将这些完整的亚当优化步骤中的 N 个并行构建为单个静态图,并快速将数据传递给它。
声称 使用现代 Tensorflow 或 pytorch 库中的任何一个,似乎都不可能达到 Tensorflow 1.x 的一半速度,这使得更新非常没有吸引力。
有没有一种方法可以在我缺少的 Tensorflow 2.x 或 pytorch 中获得大幅加速?