我想用 tensorflow 训练 GAN,然后将生成器和鉴别器导出为 tensorflow_hub 模块。
为此:
- 我用 tensorflow 定义我的 GAN 架构
- 训练它并保存检查点
- 创建一个具有不同标签的 module_spec,例如:
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- 使用我在训练期间保存的 checkpoint_path 在 tf_hub_path 使用 module_spec 导出
然后,我可以使用以下命令加载生成器:
hub.Module(tf_hub_path, tags={"gen", "bs8"})
但是,当我尝试使用类似的命令加载鉴别器时:
hub.Module(tf_hub_path, tags={"disc", "bs8"})
我得到了错误:
ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}
因此,我得出结论,鉴别器中存在的变量没有保存在磁盘上的模块中。我检查了我想象的不同错误来源:
- 正确定义了模块规范。为此,我决定训练我的模型,创建模块规范并直接从该 module_spec 加载模块。这对生成器和鉴别器都很好。然后,我假设我的 module_spec 是正确的
然后,我想知道检查点是否正确地将所有变量保存在我的图表中。
checkpoint_path = tf.train.latest_checkpoint(self.model_dir) inspect_list = tf.train.list_variables(checkpoint_path) print(inspect_list) [('disc_step_1/beta1_power', []), ('disc_step_1/beta2_power', []), ('discriminator/linear/bias', [1]), ('discriminator/linear/bias/d_opt', [1]), ('discriminator/linear/bias/d_opt_1', [1]), ('discriminator/linear/kernel', [3, 1]), ('discriminator/linear/kernel/d_opt', [3, 1]), ('discriminator/linear/kernel/d_opt_1', [3, 1]), ('gen_step/beta1_power', []), ('gen_step/beta2_power', []), ('generator/fc_noise/bias', [48]), ('generator/fc_noise/bias/g_opt', [48]), ('generator/fc_noise/bias/g_opt_1', [48]), ('generator/fc_noise/kernel', [2, 48]), ('generator/fc_noise/kernel/g_opt', [2, 48]), ('generator/fc_noise/kernel/g_opt_1', [2, 48]), ('global_step', []), ('global_step_disc', [])]
因此,我看到所有变量都正确保存在检查点内。只有与生成器相关的两个变量在磁盘上的 tf hub 模块中正确导出。
最后,我想我的错误来自:
module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)
从 checkpoint_path 导出变量时,只考虑标签“gen”。我还检查了 module.variable_map 和检查点路径中的列表变量之间的变量名称是否对应。这是带有标签“disc”的模块的变量映射:
print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}
我有
- 张量流:1.13.1
- 张量流集线器:0.4.0
- 蟒蛇:3.5.2
谢谢你的帮助