0

通过 Detectron2 训练了一个模块后,我尝试将模型导出到 TorchScript,然后出现以下错误:

无法导出 Python 函数调用“_ScaleGradient”。删除对 Python 函数的调用 > 在导出之前。您是否忘记添加 @script 或 @script_method 注释?如果这是 > nn.ModuleList,请将其添加到 __constants__

我发现代码在detectron2/modeling/roi_heads/cascade_rcnn.py

class _ScaleGradient(Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

所以我将@statcmethod annos 更改为@torch.jit.script_method,之后出现“'ScriptMethodStub' 对象不可调用”错误。

我对torchscript不熟悉,如何解决这个问题?

提前致谢。

4

1 回答 1

0

在推理阶段似乎不需要 _ScaleGradient 方法,所以我只需将以下代码添加到 cacasde_rcnn.py

if self.training:
    #call _ScaleGradient.apply
else:
    #don't call _ScaleGradient.apply
于 2021-03-27T03:38:43.673 回答