我想加速一个平均多个 TensorFlow 检查点的工具,但为了简单起见,假设我只需要加载一个检查点,可能修改一些变量并将其保存回磁盘。
当前实现将所有变量加载到numpy数组(3秒),tf.get_variable()
为每个变量准备tf变量(),占位符和assign_ops(20秒),执行初始化所有变量的会话(16秒),运行所有赋值(81秒),最后它将检查点存储到磁盘(24 秒)。总时间为144 秒。
我的替代实现使用tf.get_variable(name, shape=numpy_array.shape, initializer=tf.constant_initializer(numpy_array))
and no placeholders 也没有 assign_ops,因此它将总时间减少到57 seconds。但是,该*.meta
文件还存储了所有常量初始化程序(因此它与主检查点数据文件一样大,这不是我想要的),当我将其应用于更大的检查点时,由于2GB tf 限制而失败。
如果加载所有变量需要 3 秒,我相信存储它们的时间应该远少于 141 秒,甚至少于 54 秒。{var_name1: numpy_array1,...}
有没有办法在不需要运行 tf 会话的情况下将dict写入 tf 检查点文件(从另一个检查点重用元图)?我试图按照相关问题中的链接进行操作,但没有成功。