我正在尝试用执行相同操作的自定义操作替换图中完成的计算。
假设该图有一个常量A
和权重变量W
,我创建了自定义操作来获取这两个输入并进行整个计算(除了权重更新的最后一步):
custom_op_tensor = custom_module.custom_op([A,W])
g_def = tf.get_default_graph().as_graph_def()
input_map = { tensor.name : custom_op_tensor }
train_op, = tf.import_graph_def(g_def, input_map=input_map, return_elements=[train_op])
导入图def后有两个W
,一个来自原图def,一个在导入图中。当我们运行训练操作时,自定义操作最终会读取旧W
的,而新W
的会更新。结果,梯度下降最终未能做正确的事情。
问题是 custom_op 的实例化需要输入权重张量W
。新W
的只有在导入后才知道。而且,导入需要自定义操作。如何解决这个问题?