我正在使用 PyTorch 为孟加拉语数字分类构建一个神经网络。我在构建数据集类以使用数据加载器加载我的数据集时遇到困难。我有一个包含所有图像的文件夹(0-9 的数字)和一个包含 2 列的 CSV 文件,第一列包含图像的名称,第二列包含标签(0-9)。这是我的数据加载器类,它可能不会导致错误。
class BDRWDataset(Dataset):
"""BDRW dataset."""
def __init__(self, csv_file, imgs_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with labels.
imgs_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.labels = pd.read_csv(csv_file).iloc[:, 1].to_numpy().reshape(-1,1)
self.imgs_dir = imgs_dir
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_path = '/content/BDRW_train/digit_' + str(idx) + '.jpg'
image = io.imread(img_name, plugin='matplotlib')
image = Image.fromarray(np.uint8(image))
label = self.labels[idx]
label = float(label)
if self.transform:
image = self.transform(image)
return (image, label)
我创建了这个类的一个实例。
transformed_dataset = BDRWDataset(csv_file='/content/labels.csv',imgs_dir='/content/BDRW_train',
transform=transforms.Compose([
Rescale((28, 28)),
transforms.Normalize((0.5,), (0.5,)),
ToTensor()
]))
我已经定义了 rescale 和 to tensor 如下
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
image, label = sample
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
return img, label
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
image, label = sample['image'], sample['label']
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
image = image.transpose((2, 0, 1))
return (torch.from_numpy(image), torch.from_numpy(label))
将数据集拆分为测试和训练,并使用 torch.utils.data.DataLoader 创建训练加载器和验证加载器
神经网络是
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2))
self.fc = nn.Linear(7*7*32, 10)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
我创建了这个类的一个实例并开始训练
cnn = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
losses = [];
for epoch in range(num_epochs):
for i, (image, label) in enumerate(valloader):
image = Variable(image.float())
label = Variable(label)
# Forward + Backward + Optimize
optimizer.zero_grad()
outputs = cnn(image)
loss = criterion(output, label)
loss.backward()
optimizer.step()
losses.append(loss.data[0]);
if (i+1) % 100 == 0:
print ('Epoch : %d/%d, Iter : %d/%d, Loss: %.4f'
%(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
这是我得到错误的地方
AttributeError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in open(fp, mode)
2812 try:
-> 2813 fp.seek(0)
2814 except (AttributeError, io.UnsupportedOperation):
AttributeError: 'str' object has no attribute 'seek'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
9 frames
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in open(fp, mode)
2813 fp.seek(0)
2814 except (AttributeError, io.UnsupportedOperation):
-> 2815 fp = io.BytesIO(fp.read())
2816 exclusive_fp = True
2817
AttributeError: 'str' object has no attribute 'read'
它指的是我在数据加载器中使用的 PIL Image。所以在我看来,这就是我做错了什么。
https://colab.research.google.com/drive/17XdP7gUoMNLxPCJ6PHEi3B09UQitzKyf?usp=sharing
这是我正在处理的笔记本。请帮我调试代码中的错误。
https://drive.google.com/open?id=1DznuHV9Fi5jVEbGdP-tg3ckmp5CNDOj1
这是我正在处理的数据集。