0

我正在尝试优化 GPT2 的推理时间。在 Google Colab 上调用脚本后生成样本的当前时间为 55 秒。我输入了时间戳以尝试找出瓶颈在哪里。这是代码:

 for _ in range(nsamples // batch_size):
            out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)

线

out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):] 

是复杂性所在。有没有人有办法改进这段代码?太感谢了!

4

1 回答 1

1

batch_size 在 GPT2 中设置为 1,并且没有办法在不使进程崩溃的情况下更改它。所以“[context_tokens for _ in range(batch_size)]”的意思是“[context_tokens for _ in range(1)]”的意思是“[context_tokens]”,它不会大大提高速度,但可以安全地实现并查看代码理智一点。真正的复杂性是您在该会话中访问的 ram 中有一个 6 GB 的巨兽。

实际上,您发送的令牌越少,处理这些令牌的时间越少,这部分执行的速度就越快。因为每个令牌都需要通过 GPT2 AI 发送。但因此,响应将变得越不“智能”。

顺便说一下 // 是整数除法运算,所以 nsamples // batch_size = nsamples/1 = nsamples 大小。从我看到的情况来看,当我在 print(nsamples) 中打印它的值时,nsamples 是 1。这样for循环是一个项目的另一个循环,这意味着可以删除循环。

GPT2 只是 tensorflow 的一个实现。查找:如何在tensorflow中制作图表;如何为该图调用会话;如何使保护程序保存该会话中的变量以及如何使用保护程序恢复会话。您将了解检查点、元文件和其他使您的文件更有意义的实现。

tensorflow 模块位于 Lib、site-packages、tensorflow_core(至少在 AI Dungeon 2 Henk717 fork 中)。大多数处理发生在子目录 python/ops 和 framework 中。如果您的编码破坏了 tf 所期望的钩子,您将看到这些弹出窗口。

如果这个问题与 AI Dungeon 中的实现有关,那么我能够实现的最好的方法是对 generator.generate 的递归调用,该调用由尝试退出,除了 KeyboardInterrupt: with a print(token, end = '', flush = True) for每个令牌生成时。通过这种方式,您可以在 AI 生成每个令牌时查看它,而不是等待 55 秒等待 ping 声。

此外,Cuda 警告需要单引号,而不是双引号,因此 import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' not "3" 这将在导入 tensorflow 时取消 cuda 警告。

接下来,在 1.5 以上的 tensorflow 版本中,GPT2 的实现会弹出折旧。

关闭那些 tfv = tf.compat.v1 tfv.set_verbosity(tfv.logging.Error) 就是你所需要的。您不需要导入警告。

即便如此,在 tf 初始化、示例初始生成和将模块加载到 ram 之间还是有很长的加载时间。我在 model.shape_list(x) 中添加:以下行 print("_",end ='',flush=True) 至少对于正在构建的模块以将其本地化到机器,您可以查看“进度条”各种各样的。

于 2021-07-27T01:54:51.223 回答