I have 2-D arrays, and I compute the loss with the following code:
import chainer.functions as F

# Pick the supervised entries out of the 2-D score and label arrays
_roi_score = roi_score[row_index, col_index]
gt_roi_label_lst = gt_roi_label_lst[row_index, col_index]
loss = F.sigmoid_cross_entropy(_roi_score, gt_roi_label_lst)  # multi-label
During backpropagation, the code fails with this error:
File "AU_rcnn/train.py", line 249, in main
trainer.run()
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/training/trainer.py", line 324, in run
six.reraise(*sys.exc_info())
File "/usr/local/anaconda3/lib/python3.6/site-packages/six.py", line 686, in reraise
raise value
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/training/trainer.py", line 310, in run
update()
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/training/updater.py", line 223, in update
self.update_core()
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/training/updater.py", line 367, in update_core
loss.backward()
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/variable.py", line 916, in backward
target_input_indexes, out_grad, in_grad)
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/function_node.py", line 486, in backward_accumulate
gxs = self.backward(target_input_indexes, grad_outputs)
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/function.py", line 124, in backward
gxs = self._function.backward(in_data, grad_out_data)
File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0b1-py3.6.egg/chainer/functions/connection/linear.py", line 56, in backward
gb = gy.sum(0)
File "cupy/core/core.pyx", line 967, in cupy.core.core.ndarray.sum
File "cupy/core/core.pyx", line 975, in cupy.core.core.ndarray.sum
File "cupy/core/reduction.pxi", line 216, in cupy.core.core.simple_reduction_function.__call__
File "cupy/core/elementwise.pxi", line 102, in cupy.core.core._preprocess_args
ValueError: Array device must be same as the current device: array device = 1 while current = 0
The error appears even though I am only using one GPU. What causes it? I have been stuck on this for a long time.
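The traceback shows the failure inside the backward pass of `linear.py` at `gy.sum(0)`: CuPy refuses to run a reduction on an array that lives on GPU 1 while GPU 0 is the current device. One plausible cause (an assumption, since the training script is not shown) is that the model or the batches were moved to GPU 1, for example with `model.to_gpu(1)`, while the current CUDA device stayed at its default of 0. The sketch below is a small diagnostic you could call just before `loss.backward()` to see which arrays sit on a different device. The helper names `array_device_id` and `find_device_mismatch` are hypothetical, not part of Chainer or CuPy; the only library behavior relied on is that a CuPy array's `.device` is a `cupy.cuda.Device` with an `id` attribute.

```python
def array_device_id(arr):
    """Return the CUDA device id of a CuPy array, or None for CPU arrays.

    CuPy arrays expose ``arr.device`` as a ``cupy.cuda.Device`` whose ``id``
    is the GPU index. NumPy arrays and other objects have no such ``id``.
    """
    return getattr(getattr(arr, "device", None), "id", None)


def find_device_mismatch(named_arrays, current_device_id):
    """Return {name: device_id} for every GPU array not on the current device.

    CPU arrays (device id None) are ignored, because the error in the
    traceback concerns GPU arrays only.
    """
    mismatched = {}
    for name, arr in named_arrays.items():
        dev = array_device_id(arr)
        if dev is not None and dev != current_device_id:
            mismatched[name] = dev
    return mismatched


# Hypothetical usage right before loss.backward() (requires CuPy):
#   import cupy
#   current = cupy.cuda.Device().id  # Device() with no argument = current device
#   print(find_device_mismatch({"roi_score": roi_score.data,
#                               "gt_roi_label_lst": gt_roi_label_lst}, current))
```

If it reports arrays on device 1 while the current device is 0, a common remedy is to make that GPU current before training starts, e.g. `chainer.cuda.get_device_from_id(1).use()`, and to pass the same id as the updater's `device` argument. Alternatively, launching the script with `CUDA_VISIBLE_DEVICES=1` and using device 0 inside the code exposes only one GPU to the process.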