我正在关注并实现这个关于 Torchtext 的简短教程中的代码,鉴于 Torchtext 的文档很差,这非常清楚。
创建迭代器(批处理生成器)后,他建议创建一个包装器以生成更多可重用的代码。(参见教程中的第 5 步)。
该代码包含一个令人惊讶的长而奇怪的行,我不明白它会引发SyntaxError: invalid syntax。有没有人知道发生了什么?
(有问题的那一行开头是:if self.y_vars is <g [...])
class BatchWrapper:
def __init__(self, dl, x_var, y_vars):
self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # we pass in the list of attributes for x <g class="gr_ gr_3178 gr-alert gr_spell gr_inline_cards gr_disable_anim_appear ContextualSpelling ins-del" id="3178" data-gr-id="3178">and y</g>
def __iter__(self):
for batch in self.dl:
x = getattr(batch, self.x_var) # we assume only one input in this wrapper
if self.y_vars is <g class="gr_ gr_3177 gr-alert gr_gramm gr_inline_cards gr_disable_anim_appear Grammar replaceWithoutSep" id="3177" data-gr-id="3177">not</g> None: # we will concatenate y into a single tensor
y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
else:
y = torch.zeros((1))
yield (x, y)
def __len__(self):
return len(self.dl)