我有一个简单的神经网络模型,我在模型上应用cuda()
或应用DataParallel()
如下。
model = torch.nn.DataParallel(model).cuda()
或者,
model = model.cuda()
当我不使用 DataParallel 时,只需将我的模型转换为cuda()
,我需要将批量输入显式转换为cuda()
然后将其提供给模型,否则它会返回以下错误。
torch.index_select 收到无效的参数组合 - 得到 (torch.cuda.FloatTensor, int, torch.LongTensor)
但是使用 DataParallel,代码可以正常工作。其余的其他事情都是一样的。为什么会发生这种情况?为什么当我使用 DataParallel 时,我不需要将批处理输入显式转换为cuda()
?