0

我正在关注并实现这个关于 Torchtext 的简短教程中的代码,鉴于 Torchtext 的文档很差,这非常清楚。

创建迭代器(批处理生成器)后,他建议创建一个包装器以生成更多可重用的代码。(参见教程中的第 5 步)。

该代码包含一个令人惊讶的长而奇怪的行,我不明白它会引发SyntaxError: invalid syntax。有没有人知道发生了什么?

(有问题的那一行开头是:if self.y_vars is <g [...])

class BatchWrapper:
  def __init__(self, dl, x_var, y_vars):
        self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # we pass in the list of attributes for x <g class="gr_ gr_3178 gr-alert gr_spell gr_inline_cards gr_disable_anim_appear ContextualSpelling ins-del" id="3178" data-gr-id="3178">and y</g>

  def __iter__(self):
        for batch in self.dl:
              x = getattr(batch, self.x_var) # we assume only one input in this wrapper

              if self.y_vars is <g class="gr_ gr_3177 gr-alert gr_gramm gr_inline_cards gr_disable_anim_appear Grammar replaceWithoutSep" id="3177" data-gr-id="3177">not</g> None: # we will concatenate y into a single tensor
                    y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
              else:
                    y = torch.zeros((1))

              yield (x, y)

  def __len__(self):
        return len(self.dl)
4

1 回答 1

1

是的,估计是作者写错了。我认为正确的代码是这样的:

if self.y_vars is not None:
    y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
else:
    y = torch.zeros((1))

您也可以在第 3 行的评论中看到这个错字(在博文中的代码中)。

于 2018-10-09T02:56:05.813 回答