我在 datatset.py 模块中遇到了这个错误,该模块用于为孪生网络(Siamese network)创建正样本对和负样本对。有人知道为什么会出现这个错误吗?我在主模块“itr1.py”中,对我自己的自定义数据集的训练数据集调用了这个 SiameseMNIST 类。
# Wrap the existing train/test datasets so that they yield image pairs.
# NOTE(review): `train_dataset` and `test_dataset` are defined outside this
# excerpt (presumably in the caller module) - confirm. SiameseMNIST pulls the
# labels, the image data and the transform from attributes of the dataset it
# is given, so both objects must expose those attributes.
siamese_train_dataset = SiameseMNIST(train_dataset)
siamese_test_dataset = SiameseMNIST(test_dataset)
class SiameseMNIST(Dataset):
    """Pair-producing wrapper around a labelled image dataset (Siamese setup).

    In train mode every ``__getitem__`` call randomly draws a partner image:
    target 1 means "same class", target 0 means "different class".  In test
    mode the pairs are generated once in ``__init__`` from a fixed seed so
    that evaluation is repeatable.

    Args:
        dataset: The wrapped dataset.  It must support ``len()`` and expose
            its labels as ``train_labels`` / ``test_labels`` (legacy MNIST
            attribute names) or ``targets``, and its images as
            ``train_data`` / ``test_data`` or ``data``.  Images are indexed
            and converted with ``.numpy()`` before being turned into PIL
            images with ``mode='L'``, so the image container is expected to
            be a tensor.
        train: Mode to use when ``dataset`` has no ``train`` attribute of its
            own.  When the wrapped dataset does define ``train`` that value
            wins, which keeps the behaviour for MNIST-style datasets
            unchanged.

    Raises:
        AttributeError: If the wrapped dataset exposes none of the expected
            label or image attributes.
    """

    def __init__(self, dataset, train=True):
        self.dataset = dataset
        # Custom datasets usually lack `train` / `transform`; fall back to
        # the constructor argument / "no transform" instead of crashing.
        self.train = getattr(dataset, 'train', train)
        self.transform = getattr(dataset, 'transform', None)

        if self.train:
            self.train_labels = self._dataset_attr('train_labels', 'targets')
            self.train_data = self._dataset_attr('train_data', 'data')
            labels = self.train_labels
        else:
            self.test_labels = self._dataset_attr('test_labels', 'targets')
            self.test_data = self._dataset_attr('test_data', 'data')
            labels = self.test_labels

        # One array conversion up front (works for tensors and plain
        # sequences) instead of one conversion per label.
        self._labels = np.asarray(labels)
        self.labels_set = set(self._labels)
        self.label_to_indices = {label: np.where(self._labels == label)[0]
                                 for label in self.labels_set}

        if not self.train:
            # generate fixed pairs for testing
            self.test_pairs = self._make_test_pairs()

    def _dataset_attr(self, *names):
        """Return the first attribute of the wrapped dataset found in *names*."""
        for name in names:
            try:
                return getattr(self.dataset, name)
            except AttributeError:
                continue
        raise AttributeError(
            "wrapped dataset {!r} has none of the attributes: {}".format(
                type(self.dataset).__name__, ", ".join(names)))

    def _make_test_pairs(self):
        """Build the reproducible ``[index, partner_index, target]`` triples."""
        # Every draw goes through this seeded generator; mixing in the global
        # `np.random` would make the "fixed" pairs differ between runs.
        random_state = np.random.RandomState(29)
        total = len(self.test_data)

        # Even indices get a same-class partner (target 1).
        positive_pairs = [
            [i,
             random_state.choice(
                 self.label_to_indices[self._labels[i].item()]),
             1]
            for i in range(0, total, 2)
        ]

        # Odd indices get a partner from a different class (target 0).
        negative_pairs = []
        for i in range(1, total, 2):
            # Sorted so the candidate order does not depend on set iteration.
            other_labels = sorted(self.labels_set - {self._labels[i].item()})
            other_label = random_state.choice(other_labels)
            negative_pairs.append(
                [i, random_state.choice(self.label_to_indices[other_label]), 0])

        return positive_pairs + negative_pairs

    def __getitem__(self, index):
        """Return ``((img1, img2), target)`` for the given index."""
        if self.train:
            target = np.random.randint(0, 2)
            img1, label1 = self.train_data[index], self._labels[index].item()
            if target == 1:
                same_class = self.label_to_indices[label1]
                siamese_index = index
                # A class with a single sample has no other partner; pair the
                # sample with itself rather than looping forever.
                if len(same_class) > 1:
                    while siamese_index == index:
                        siamese_index = np.random.choice(same_class)
            else:
                siamese_label = np.random.choice(
                    list(self.labels_set - {label1}))
                siamese_index = np.random.choice(
                    self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            first, second, target = self.test_pairs[index]
            img1 = self.test_data[first]
            img2 = self.test_data[second]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)