我正在尝试将 ExplainerDashboard 用于我的 PyTorch 神经网络分类器。根据GitHub的要求,我将其作为 Skorch 模型传递,如下所示,它在拟合后正确生成模型。
def get_skorch_classifier():
X_train_m = X_train.astype(np.float32)
y_train_m = y_train.astype(np.float32)
X_train_df = pd.DataFrame(X_train_m, columns=X.columns)
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.layer_1 = nn.Linear(298, 60)
self.layer_2 = nn.Linear(60, 60)
self.layer_3 = nn.Linear(60, 60)
self.layer_4 = nn.Linear(60, 60)
self.layer_out = nn.Linear(60, 1)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
self.batchnorm1 = nn.BatchNorm1d(60, momentum=0.2)
self.batchnorm2 = nn.BatchNorm1d(60, momentum=0.2)
self.batchnorm3 = nn.BatchNorm1d(60, momentum=0.2)
self.batchnorm4 = nn.BatchNorm1d(60, momentum=0.2)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
x = self.relu(self.layer_1(inputs))
x = self.batchnorm1(x)
x = self.dropout(x)
x = self.relu(self.layer_2(x))
x = self.batchnorm2(x)
x = self.dropout(x)
x = self.relu(self.layer_3(x))
x = self.batchnorm3(x)
x = self.dropout(x)
x = self.relu(self.layer_4(x))
x = self.batchnorm4(x)
#x = self.dropout(x)
x = self.layer_out(x)
#x = self.sigmoid(x)
return x
model = NeuralNetBinaryClassifier(MyModule, max_epochs=10, lr=0.01, optimizer=optim.Adam)
model.fit(X_train_m, torch.FloatTensor(y_train_m))
return model, X_train_df, y_train_m
model, Xm_df, ym = get_skorch_classifier()
然后,当试图将它传递给分类器解释器函数时,它会导致 TypeError
explainer = ClassifierExplainer(model, Xm_df, y_test)
ExplainerDashboard(explainer, mode='inline').run(port=8051)
> TypeError Traceback (most recent call
> last) <ipython-input-17-76facb1989af> in <module>()
> ----> 1 explainer = ClassifierExplainer(model, Xm_df, y_test)
> 2 ExplainerDashboard(explainer, mode='inline').run(port=8051)
>
> 10 frames
> /usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in
> _call_impl(self, *input, **kwargs) 1049 if not (self._backward_hooks or self._forward_hooks or
> self._forward_pre_hooks or _global_backward_hooks 1050
> or _global_forward_hooks or _global_forward_pre_hooks):
> -> 1051 return forward_call(*input, **kwargs) 1052 # Do not call functions when jit is used 1053 full_backward_hooks, non_full_backward_hooks = [], []
>
> TypeError: forward() got an unexpected keyword argument 'temp'
'temp' 是我的 DataFrame 中的列名:
任何想法为什么会发生此错误以及如何解决它?