2

我正在尝试加快我在torch7中实现的网络,但是当我尝试使用nn.DataParallelTable时出现错误。这就是我想要做的:

m1, m2 = createModel(8,48), createModel(8,48)
--8 # of GPUs, 48 hidden unit in the last layer
m2:share(m1,'weight', 'bias') ----THE ERROR IS HERE
prl = nn.ParallelTable()
prl:add(m1)
prl:add(m2)
prl:cuda()
mlp = nn.Sequential()
mlp:add(prl)
mlp:cuda()
crit = nn.CosineEmbeddingCriterion():cuda()

函数在哪里:

function createModel(nGPU,bot)
local features = nn.Concat(2)
local fb1 = nn.Sequential() -- branch 1
fb1:add(nn.SpatialConvolution(1,48,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))               
fb1:add(nn.SpatialConvolution(48,128,3,3,1,1,1,1))        
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(128,192,3,3,1,1,1,1))     
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(192,192,3,3,1,1,1,1))      
fb1:add(nn.ReLU(true))
fb1:add(nn.SpatialConvolution(192,128,3,3,1,1,1,1))
fb1:add(nn.ReLU(true))  
fb1:add(nn.SpatialMaxPooling(2,2,2,2))    
view = 12  
local fb2 = fb1:clone() -- branch 2
for k,v in ipairs(fb2:findModules('nn.SpatialConvolution')) do
v:reset() -- reset branch 2's weights
end
features:add(fb1) features:add(fb2) features:cuda()

--------------the error is at this line-----------
features = makeDataParallel(features, nGPU)

local classifier = nn.Sequential()    
classifier:add(nn.View(256viewview))
classifier:add(nn.Dropout(0.5))    
classifier:add(nn.Linear(256viewview, 4096))
classifier:add(nn.Dropout(0.5))
classifier:add(nn.Linear(4096, 4096))
classifier:add(nn.Tanh())
classifier:add(nn.Linear(4096, bot))
classifier:add(nn.Tanh())
classifier:cuda()
local model = nn.Sequential():add(features):add(classifier)
return model
end

另一个是:

function makeDataParallel(model, nGPU)
if nGPU > 1 then
print('converting module to nn.DataParallelTable')
assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
local model_single = model
model = nn.DataParallelTable(1)
for i=1, nGPU do
cutorch.setDevice(i)
model:add(model_single:clone():cuda(), i)
end
end
cutorch.setDevice(1)
return model
end

我得到的错误是:

[C]: in function 'error'
...a/torch/install/share/lua/5.1/cunn/DataParallelTable.lua:337: in function 'share'
/home/andrea/torch/install/share/lua/5.1/nn/Container.lua:97: in function 'share'
main.lua:123: in main chunk
[C]: at 0x00406670

您可能知道错误在哪里吗?抱歉,但我对此有点陌生,我找不到解决方法。当然我弄错了网络结构。提前致谢。

4

0 回答 0