4

我现在在 Windows 中使用 pytorch 0.4.0 来构建 CNN,这是我的代码:

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1,3),stride=1 )

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1,3), stride=1)

        self.dense1 = nn.Linear(32 * 28 * 24, 60)
        self.out = nn.Linear(60,3)

    def forward(self, input):
        x = F.relu(self.conv1(input))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1) # flatten(batch,32*7*7)
        x = self.dense1(x)
        output = self.out(x)
        return output

但我得到的错误

File "D:\Anaconda\lib\site-packages\torch\nn\modules\conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)

RuntimeError: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]

我认为这表明我在上面的代码中犯了一些错误,但我不知道如何修复它,任何人都可以帮助我吗?提前致谢!

4

2 回答 2

2

好吧,也许我知道发生了什么,因为我在 4 或 5 小时前遇到了同样的运行时错误。

这是我的解决方案(我自己定义了数据集):

我输入网络的图像是 1 个通道,与您的代码(self.conv1 = nn.Conv2d(in_channels=1,...))相同。并且会带来运行时错误的图像属性如下:

error_img

在此处输入图像描述

我修复的图像如下:

固定图像

在此处输入图像描述

你能感觉到不同吗?输入图像的通道应该是 1,所以img.shape()应该是元组!使用img.reshape(1,100,100)它来修复它,网络的转发功能将继续进行。

我希望它可以帮助你。

于 2018-05-07T02:52:25.963 回答
0

原因之一可能是input输入模型进行处理;input必须缺少其中一个维度。

尝试torch.unsqueeze(input, 0)

于 2018-07-20T10:34:10.320 回答