我从阅读教程中得到的印象是这很容易:打个电话给add_graph()
.
第一个棘手的部分是找出把它放在哪里。我在循环的最顶部使用了fairseq_cli/train.py:enumerate()
train()
for i, samples in enumerate(progress):
if i == 0:
# Output graph for tensorboard
writer = progress._writer("") #The "" is tag
writer.add_graph(trainer._model, samples)
writer.flush()
我正在--tensorboard-logdir mydir/
调用fairseq-train
. 这会导致TensorboardProgressBarWrapper
包装器SimpleProgressBar
(或您使用的任何日志记录格式),因此我正在尝试重新使用编写器实例。(也许这是我的错?)
我可以从对象中取出模型,trainer
我们需要的另一件事add_graph
是一些数据,这就是为什么我把它放在上面的循环中,所以我可以使用samples
. 我也尝试samples[0]
过完全相同的错误消息。
这是:
Tracer cannot infer type of ({'id': tensor([743216, 642485, 92182, 793806, 494734, 275334, 53282, 449572, 1758,
...
20734, 469070, 678489, 473213]), 'nsentences': 832, 'ntokens': 29775, 'net_input': {'src_tokens': tensor([[ 225, 30, 874, ..., 1330, 84, 2],
[ 442, 734, 7473, ..., 38, 5, 2],
...,
[ 6238, 11411, 428, ..., 10387, 5, 2]]), 'src_lengths': tensor([21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
...
21, 21, 21, 21]), 'prev_output_tokens': tensor([[ 2, 22, 133, ..., 4445, 46, 1],
...,
[ 2, 22, 59, ..., 112, 6, 1]])}, 'target': tensor([[ 22, 133, 203, ..., 46, 2, 1],
[132, 429, 40, ..., 46, 2, 1],
...,
[ 22, 59, 177, ..., 6, 2, 1]])},)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and int
Error occurs, No graph saved
我希望有人可能比我更好地理解该错误消息?也欢迎有关更好的调用位置add_graph()
或处理整个事情的更好方法(获得我的模型的可视化表示)的提示。(PyTorchViz已经 2 年多没有更新了,显然不能与最新的 pytorch 一起使用。)
顺便说一句,训练的 Tensorboard 记录正在工作。
额外的
我又看了看,发现它会调用forward()
我的模型,它(忽略可选参数)看起来像:
def forward(self, src_tokens, src_lengths, prev_output_tokens):
所以我尝试手动创建空数据:
dummy_data = {'src_tokens':torch.zeros((1,256)),'src_lengths':torch.zeros((1,256)),'prev_output_tokens':torch.zeros((1,256))}
writer.add_graph(trainer._model,dummy_data,verbose=True)
那失败了:
TypeError: forward() missing 2 required positional arguments: 'src_lengths' and 'prev_output_tokens'
如果我使用 kwargs 风格:
dummy_data = {'src_tokens':torch.zeros((1,256)),'src_lengths':torch.zeros((1,256)),'prev_output_tokens':torch.zeros((1,256))}
writer.add_graph(trainer._model,**dummy_data,verbose=True)
它失败了:
TypeError: add_graph() got an unexpected keyword argument 'src_tokens'
我开始认为这可能是一个限制,add_graph()
它只适用于简单的def forward(self, x):
?