编辑:我试过 PyTorch 1.6.0 和 1.7.1,都给了我同样的错误。
我有一个模型,可以让用户在不同的架构 A 和 B 之间轻松切换。两种架构的转发功能也不同,所以我有以下模型类:
PS我这里只是用一个非常简单的例子来演示我的问题,实际模型要复杂得多。
class Net(nn.Module):
def __init__(self, condition):
super().__init__()
self.linear = nn.Linear(10, 1)
if condition == 'A':
self.forward = self.forward_A
elif condition == 'B':
self.linear2 = nn.Linear(10, 1)
self.forward = self.forward_B
def forward_A(self, x):
return self.linear(x)
def forward_B(self, x1, x2):
return self.linear(x1) + self.linear2(x2)
它在单个 GPU 情况下运行良好。然而,在多 GPU 的情况下,它会给我一个错误。
device= 'cuda:0'
x = torch.randn(8,10).to(device)
model = Net('B')
model = model.to(device)
model = nn.DataParallel(model)
model(x, x)
RuntimeError: 期望所有张量都在同一个设备上,但发现至少有两个设备,cuda:0 和 cuda:1!(在方法 wrapper_addmm 中检查参数 mat1 的参数时)
如何使这个模型类工作nn.DataParallel
?