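Hi, I'm trying to implement a Siamese neural network for one-shot image recognition on the Omniglot dataset. The first step of the implementation is to generate pairs of samples from the same class and from different classes. For this I used the make_pairs function from Ben Myara's GitHub with almost no modifications. However, every call to the function raises a KeyError, and I'd like to understand what causes it. Here is my implementation: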
```python
import io
import random

import numpy as np
import requests
import torch


def load_numpy_arr_from_url(url):
    """
    Loads a numpy array from surfdrive.
    Input:
        url: Download link of dataset
    Outputs:
        dataset: numpy array with input features or labels
    """
    response = requests.get(url)
    response.raise_for_status()
    return np.load(io.BytesIO(response.content))


# Downloading may take a while...
train_x = load_numpy_arr_from_url('https://surfdrive.surf.nl/files/index.php/s/tvQmLyY7MhVsADb/download')
# Transform bool type to integer
train_data = train_x * 1
train_y = load_numpy_arr_from_url('https://surfdrive.surf.nl/files/index.php/s/z234AHrQqx9RVGH/download')


def make_pairs(data, labels, num=1000):
    digits = {}
    for i, j in enumerate(labels):
        if not j in digits:
            digits[j] = []
        digits[j].append(i)
    pairs, labels_ = [], []
    for i in range(num):
        if np.random.rand() >= .5:  # same digit
            digit = random.choice(range(len(labels+1)))
            d1, d2 = random.choice(digits[digit], size=2, replace=False)
            labels_.append(1)
        else:
            digit1, digit2 = np.random.choice(range(len(labels+1)), size=2, replace=False)
            d1, d2 = random.choice(digits[digit1]), np.random.choice(digits[digit2])
            labels_.append(0)
        pairs.append(torch.from_numpy(np.concatenate([data[d1,:], data[d2,:]])).view(1, 56, 28))
    return torch.cat(pairs), torch.LongTensor(labels_)
```
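For reference, the last step of each loop iteration stacks two samples and reshapes them with view(1, 56, 28). Assuming each sample is a 28×28 image (or a flattened 784-pixel row; the actual shape of train_x is not shown above), this NumPy-only sketch illustrates the shapes involved. The arrays img_a and img_b are hypothetical stand-ins, and images of any other size would need different view dimensions.

```python
import numpy as np

# Hypothetical stand-ins for two samples; assumes 28x28 images
img_a = np.zeros((28, 28), dtype=np.int64)
img_b = np.ones((28, 28), dtype=np.int64)

# np.concatenate joins along axis 0: (28, 28) + (28, 28) -> (56, 28)
pair = np.concatenate([img_a, img_b]).reshape(1, 56, 28)

# Flattened 784-pixel rows also fit: (784,) + (784,) -> (1568,) -> (1, 56, 28)
flat_pair = np.concatenate([img_a.ravel(), img_b.ravel()]).reshape(1, 56, 28)
```

torch.from_numpy(...).view(1, 56, 28) in the original performs the same reshape on a tensor that shares memory with the NumPy array.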
The error occurs when I call the function like this:
```python
make_pairs(train_data, train_y, 5)
```
This is the traceback I get:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-30-7d53181e46ef> in <module>()
     25 
     26     return torch.cat(pairs), torch.LongTensor(labels_)
---> 27 make_pairs(train_data,train_y, 5)
     28 #print(a)

<ipython-input-30-7d53181e46ef> in make_pairs(data, labels, num)
     14         if np.random.rand() >= .5: # same digit
     15             digit = random.choice(range(len(labels+1)))
---> 16             print(random.choice(digits[digit], replace=False))
     17             d1, d2 = random.choice(digits[digit], size=2, replace=False)
     18             labels_.append(1)

KeyError: 12803
```
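A likely cause, sketched on hypothetical toy labels (class ids 5, 7 and 9, made up for illustration): digits is keyed by label *values*, but labels+1 only adds 1 to every element and leaves the length unchanged, so range(len(labels+1)) ranges over sample *positions* 0..N-1. Any drawn position that is not also a label value raises KeyError; 12803 in the traceback is presumably such a position. Separately, Python's built-in random.choice(seq) accepts only the sequence, so the size=2, replace=False call on the next line would raise a TypeError even with a valid key. np.random.choice is the function that accepts those arguments.

```python
import random

import numpy as np

labels = np.array([5, 5, 7, 7, 7, 9])  # hypothetical toy labels (class ids 5, 7, 9)
digits = {}
for i, j in enumerate(labels):
    digits.setdefault(j, []).append(i)

# labels + 1 changes the values, not the length: still 6 elements
n = len(labels + 1)

# range(n) yields sample positions 0..5; only 5 happens to also be a label value
missing = [k for k in range(n) if k not in digits]

# Looking up a position that is not a label value reproduces the error
try:
    digits[0]
    key_error = False
except KeyError:
    key_error = True

# The stdlib random.choice takes only a sequence; size/replace are NumPy arguments
try:
    random.choice(digits[7], size=2, replace=False)
    type_error = False
except TypeError:
    type_error = True

# Draw the class from the dictionary keys instead (only classes with >= 2 samples,
# since choosing 2 without replacement from a single sample raises ValueError)
candidates = [c for c in digits if len(digits[c]) >= 2]
digit = random.choice(candidates)
d1, d2 = np.random.choice(digits[digit], size=2, replace=False)
```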
I also tried running part of the function outside the for loop, and there everything seems to work:
```python
import numpy as np
import torch

digits = {}
for i, j in enumerate(train_y):
    if not j in digits:
        digits[j] = []
    digits[j].append(i)
pairs, labels_ = [], []
digit = np.random.choice(range(len(train_y)+1))
d1, d2 = np.random.choice(digits[digit], size=2, replace=False)
labels_.append(1)
print(torch.from_numpy(np.concatenate([train_data[d1,:], train_data[d2,:]])).view(1, 56, 28))
```
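The standalone snippet also draws from range(len(train_y)+1), i.e. from sample positions, so unless the label values happen to cover that whole range it presumably only succeeds when the drawn index coincides with an existing label value. Below is a possible corrected make_pairs, offered as a sketch rather than a definitive fix. It draws classes from the keys of digits, uses NumPy's choice wherever size/replace are needed and, as an added assumption, only picks "same" classes that have at least two samples. Two further departures from the original: it uses a np.random.Generator (the rng parameter is hypothetical, added for reproducible seeding) and returns NumPy arrays instead of tensors.

```python
import numpy as np


def make_pairs(data, labels, num=1000, rng=None):
    """Sample `num` pairs; pair label 1 = same class, 0 = different classes."""
    rng = np.random.default_rng() if rng is None else rng

    # Group sample positions by class label (keys are label values, not positions)
    digits = {}
    for i, j in enumerate(labels):
        digits.setdefault(j, []).append(i)

    classes = list(digits.keys())
    # Assumption: only classes with >= 2 samples can form a "same" pair
    same_ok = [c for c in classes if len(digits[c]) >= 2]

    pairs, pair_labels = [], []
    for _ in range(num):
        if rng.random() >= 0.5:  # same class
            c = same_ok[rng.integers(len(same_ok))]
            d1, d2 = rng.choice(digits[c], size=2, replace=False)
            pair_labels.append(1)
        else:  # two different classes, chosen by index into `classes`
            c1, c2 = rng.choice(len(classes), size=2, replace=False)
            d1 = rng.choice(digits[classes[c1]])
            d2 = rng.choice(digits[classes[c2]])
            pair_labels.append(0)
        # Stack the two samples along the first axis, as in the original
        pairs.append(np.concatenate([data[d1], data[d2]]))
    return np.stack(pairs), np.array(pair_labels, dtype=np.int64)
```

For 28×28 images the returned pairs have shape (num, 56, 28), the same shape torch.cat(pairs) produces in the original; torch.from_numpy and torch.LongTensor can wrap the results if tensors are needed.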