我正在低延迟环境中实现残差 cnn(修改后的 xception 的较小版本)。我做了很多手动调整以最小化我的网络的运行时间速度(减少过滤器的数量,删除层等)。
但是现在我想尝试让我的网络在每个残差块之后的残差连接上进行分类预测(最终 fcnn 层)。
基本逻辑——
尝试以残差连接作为输入的最终预测
如果这个 fcnn 层以概率 > 设定阈值预测某个类:
return fcnn output as if it was normal final layer
别的:
do next residual block like normal and try the previous conditional again unless we are already at final block
我希望这将使我的网络学会用更少的计算来解决更简单的问题,同时如果它仍然不确定分类,它仍然可以做额外的层。
所以我的基本问题是:在 pytorch 中,以允许我的 nn 在运行时决定是否进行更多处理的方式实现此条件的最佳方法是什么
目前我已经测试了在转发函数中的块之后返回中间 x,但我不知道如何最好地设置条件来选择返回哪个 x
另请注意:我相信我最终可能需要在残差和 fcnn 之间再添加一个 cnn 层,作为一个函数,将用于处理的内部表示转换为 fcnn 理解分类的表示。