0

我正在尝试将 MLP 与 2 个类一起使用。

X_train 是一个 numpy 数据数组 (217,36 (36inputs)) X_test 是一个 numpy 数据数组 (100,36) y_train 是一个 [1,0] 或 [0,1] 的 numpy 数组(指定每个类别的类别sample) 大小为 (217,2) 和 y_test 与大小 (100,2) 相同

以下是我使用 Pytorch 的代码:

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from torch import optim
X_train = train
X_test = test
y_train = y_train
y_test = y_test
X_train = torch.Tensor(X_train)
X_test = torch.Tensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)
ds_train = TensorDataset(X_train, y_train)
ds_test = TensorDataset(X_test, y_test)
loader_train = DataLoader(ds_train, batch_size=4, shuffle=True)
loader_test = DataLoader(ds_test, batch_size=4, shuffle=False)
model = nn.Sequential()
model.add_module('fc1', nn.Linear(1*36, 100))
model.add_module('relu1', nn.ReLU())
model.add_module('fc2', nn.Linear(100, 100))
model.add_module('relu2', nn.ReLU())
model.add_module('fc3', nn.Linear(100, 2))
print(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
def train(epoch):
    model.train()  
    
    for data, targets in loader_train:
        optimizer.zero_grad()  
        outputs = model(data)  
        loss = loss_fn(outputs, targets)  
        loss.backward()  
        optimizer.step()  
    print("epoch{}:\n".format(epoch))
def test():
    model.eval()  
    correct = 0
   
    with torch.no_grad():  
        for data, targets in loader_test:
            outputs = model(data)  
           
            _, predicted = torch.max(outputs.data, 1)  
            correct += predicted.eq(targets.data.view_as(predicted)).sum()  
 
    data_num = len(loader_test.dataset)  
    print('\nAccuracy: {}/{} ({:.0f}%)\n'.format(correct,
                                                   data_num, 100. * correct / data_num))


test()

但是有一个错误:

回溯(最近一次通话最后):

文件“C:\users\Harry.spyder-py3\done.py”,第 218 行,在 test()

文件“C:\users\Harry.spyder-py3\done.py”,第 212 行,测试正确 += predicted.eq(targets.data.view_as(predicted)).sum()

RuntimeError:形状“[4]”对于大小为 8 的输入无效

任何人都可以帮助我吗?我认为这是批量大小的问题,所以我尝试使用不同的批量大小但不起作用:(

4

0 回答 0