2

我从阅读教程中得到的印象是这很容易:打个电话给add_graph().

第一个棘手的部分是找出把它放在哪里。我在循环的最顶部使用了fairseq_cli/train.pyenumerate()train()

for i, samples in enumerate(progress):

    if i == 0:
        # Output graph for tensorboard
        writer = progress._writer("")  #The "" is tag
        writer.add_graph(trainer._model, samples)
        writer.flush()

我正在--tensorboard-logdir mydir/调用fairseq-train. 这会导致TensorboardProgressBarWrapper包装器SimpleProgressBar(或您使用的任何日志记录格式),因此我正在尝试重新使用编写器实例。(也许这是我的错?)

我可以从对象中取出模型,trainer我们需要的另一件事add_graph是一些数据,这就是为什么我把它放在上面的循环中,所以我可以使用samples. 我也尝试samples[0]过完全相同的错误消息。

这是:

Tracer cannot infer type of ({'id': tensor([743216, 642485,  92182, 793806, 494734, 275334,  53282, 449572,   1758,
  ...
  20734, 469070, 678489, 473213]), 'nsentences': 832, 'ntokens': 29775, 'net_input': {'src_tokens': tensor([[  225,    30,   874,  ...,  1330,    84,     2],
  [  442,   734,  7473,  ...,    38,     5,     2],
  ...,
  [ 6238, 11411,   428,  ..., 10387,     5,     2]]), 'src_lengths': tensor([21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
  ...
  21, 21, 21, 21]), 'prev_output_tokens': tensor([[   2,   22,  133,  ..., 4445,   46,    1],
  ...,
  [   2,   22,   59,  ...,  112,    6,    1]])}, 'target': tensor([[ 22, 133, 203,  ...,  46,   2,   1],
  [132, 429,  40,  ...,  46,   2,   1],
  ...,
  [ 22,  59, 177,  ...,   6,   2,   1]])},)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and int
Error occurs, No graph saved

我希望有人可能比我更好地理解该错误消息?也欢迎有关更好的调用位置add_graph()或处理整个事情的更好方法(获得我的模型的可视化表示)的提示。(PyTorchViz已经 2 年多没有更新了,显然不能与最新的 pytorch 一起使用。)

顺便说一句,训练的 Tensorboard 记录正在工作。

额外的

我又看了看,发现它会调用forward()我的模型,它(忽略可选参数)看起来像:

def forward(self, src_tokens, src_lengths, prev_output_tokens):

所以我尝试手动创建空数据:

dummy_data = {'src_tokens':torch.zeros((1,256)),'src_lengths':torch.zeros((1,256)),'prev_output_tokens':torch.zeros((1,256))}
writer.add_graph(trainer._model,dummy_data,verbose=True)

那失败了:

TypeError: forward() missing 2 required positional arguments: 'src_lengths' and 'prev_output_tokens'

如果我使用 kwargs 风格:

dummy_data = {'src_tokens':torch.zeros((1,256)),'src_lengths':torch.zeros((1,256)),'prev_output_tokens':torch.zeros((1,256))}
writer.add_graph(trainer._model,**dummy_data,verbose=True)

它失败了:

TypeError: add_graph() got an unexpected keyword argument 'src_tokens'

我开始认为这可能是一个限制,add_graph()它只适用于简单的def forward(self, x):?

4

0 回答 0