
I am running into this error in Google Colab. I tried other data types, such as a bool tensor, but it did not help. Please assist.

Code

def _mask(prev_generated_seq):
    # True where the generated sequence equals token id 1
    prev_mask = torch.eq(prev_generated_seq, 1)
    # Position of the token equal to 1 in each row, used as the sequence length.
    # This is line 101 in the traceback: it raises the RuntimeError on CUDA
    # because prev_mask is a Bool tensor.
    lengths = torch.argmax(prev_mask, dim=1)
    #test = torch.max(prev_mask, dim=1)
    #lengths = torch.FloatTensor(test)
    max_len = prev_generated_seq.size(1)
    mask = []
    for i in range(prev_generated_seq.size(0)):
        if lengths[i] == 0:
            mask_line = [0] * max_len
        else:
            # Keep the first lengths[i] tokens, mask out everything after
            mask_line = [0] * lengths[i].item()
            mask_line.extend([1] * (max_len - lengths[i].item()))
        mask.append(mask_line)
    mask = torch.ByteTensor(mask)
    if args.cuda:
        mask = mask.cuda()
    # Zero out the masked positions in place
    return prev_generated_seq.data.masked_fill_(mask, 0)

Error

File "main.py", line 179, in <module>
    train_epoches(abstracts, model, config.epochs, teacher_forcing_ratio=1)
  File "main.py", line 155, in train_epoches
    target_variables, model, teacher_forcing_ratio)
  File "main.py", line 139, in train_batch
    prev_generated_seq = _mask(prev_generated_seq)
  File "main.py", line 101, in _mask
    lengths = torch.argmax(prev_mask,dim=1)
RuntimeError: "argmax_cuda" not implemented for 'Bool'
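
For reference, a minimal workaround sketch (an assumption about the cause, not a verified fix for this model): torch.eq returns a Bool tensor, and on some PyTorch versions argmax has no Bool kernel on CUDA, so casting the mask to an integer dtype such as long before calling argmax keeps the 0/1 values and avoids the error. The helper name _mask_lengths and the sample tensor are hypothetical.

    import torch

    def _mask_lengths(prev_generated_seq):
        # Cast the Bool comparison result to long so argmax has a CUDA kernel;
        # the values stay 0/1, so the argmax result is unchanged.
        prev_mask = torch.eq(prev_generated_seq, 1).long()
        return torch.argmax(prev_mask, dim=1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Hypothetical batch of generated token ids (2 sequences of length 5)
    seqs = torch.tensor([[5, 3, 1, 0, 0],
                         [7, 2, 2, 2, 1]], device=device)
    print(_mask_lengths(seqs))  # indices of the token equal to 1: 2 and 4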