这段代码占用大量内存:
int len = Input.size(0);
at::Tensor outputs = torch::zeros({ len });
for (int SliceStart = 0; SliceStart < len; SliceStart += SliceSize)
{
std::vector<torch::jit::IValue> InputVec;
InputVec.push_back(Input.slice(0, SliceStart, SliceEnd));
output = module.forward(InputVec).toTensor();
for (int i = SliceStart; i < SliceEnd; ++i)
outputs[i] = output[i - SliceStart];
}
而下面这段代码则不会（注意最后一行赋值中调用了 item()）：
int len = Input.size(0);
at::Tensor outputs = torch::zeros({ len });
for (int SliceStart = 0; SliceStart < len; SliceStart += SliceSize)
{
std::vector<torch::jit::IValue> InputVec;
InputVec.push_back(Input.slice(0, SliceStart, SliceEnd));
output = module.forward(InputVec).toTensor();
for (int i = SliceStart; i < SliceEnd; ++i)
outputs[i] = output[i - SliceStart].item();
}
为什么会这样？两段代码唯一的区别是第二段在赋值时调用了 item()。