我正在尝试在 Pytorch 中运行以下代码:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
class H5Dataset(data.Dataset):
def __init__(self, trainx_path, trainy_path):
super(H5Dataset, self).__init__()
x_file = h5py.File(trainx_path)
y_file = h5py.File(trainy_path)
self.data = x_file.get('X')
self.target = y_file.get('y')
def __getitem__(self, size):
permutation1 = list(np.random.permutation(249000))
permutation2 = list(np.random.permutation(np.arange(249000,498000)))
size1 = int(size/2)
index1=list(permutation1[0:size1])
index2=list(permutation2[0:size1])
index = index1+index2
labels=np.array(self.target).reshape(498000,-1)
train_labels=labels[index]
train_batch=[]
for i in range(size):
img=(self.data)[index[i]]
train_batch.append(img)
train_batch=np.array(train_batch)
return (torch.from_numpy(train_batch).float(), torch.from_numpy(train_labels).float())
def __len__(self):
return len(self.data)
dataset = H5Dataset('//content//drive//My Drive//E2E_1//train_x.hdf5','//content//drive//My Drive//E2E_1//train_y.hdf5')
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=False)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(2, 64, kernel_size=(2,2), padding=(8,8), stride=(2,2),padding_mode='zeros')
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(64, 128, kernel_size=(2,2), padding=(8,8), stride=(2,2),padding_mode='zeros')
self.fc1 = nn.Linear(32768, 500)
self.drop = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(500, 1)
self.fc3 = nn.Linear(1,1)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = F.relu(self.conv2(x))
x = x.view(-1, 16 * 16 * 128)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
import torch.optim as optim
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 32 == 31: # print every 32 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 32))
running_loss = 0.0
print('Finished Training')
但我收到以下错误。我的 train_x 和 train_y 数据存储在 2 个单独的 .hdf5 文件中,当我在训练时尝试读取它们时,会弹出此错误。请任何人都可以告诉必须进行哪些更改。
OSError Traceback (most recent call last)
<ipython-input-8-a8f66fac8b5c> in <module>()
2
3 running_loss = 0.0
----> 4 for i, data in enumerate(train_loader, 0):
5 # get the inputs; data is a list of [inputs, labels]
6 inputs, labels = data
3 frames
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
392 # (https://bugs.python.org/issue2651), so we work around it.
393 msg = KeyErrorMessage(msg)
--> 394 raise self.exc_type(msg)
OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-5-0df8cfc88081>", line 37, in __getitem__
img=(self.data)[index[i]]
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "/usr/local/lib/python3.6/dist-packages/h5py/_hl/dataset.py", line 573, in __getitem__
self.id.read(mspace, fspace, arr, mtype, dxpl=self._dxpl)
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5d.pyx", line 182, in h5py.h5d.DatasetID.read
File "h5py/_proxy.pyx", line 130, in h5py._proxy.dset_rw
File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
OSError: Can't read data (wrong B-tree signature)
我收到上面显示的错误。我是 PyTorch 的新手,所以请你建议可以做什么?