(qiskit QGAN的预处理,但用例有点无关紧要)我有点迷失在试图弄清楚如何在将图像数据集传递给 GAN 之前对其进行预处理。以下是我的错误的所有相关代码。此代码源自https://github.com/Qiskit/qiskit-tutorials/blob/master/legacy_tutorials/aqua/machine_learning/qgans_for_loading_random_distributions.ipynb并已(尝试)更改以适应不同的输入数据集。(原件生成了更简单维度的虚拟数据)
# Root directory for dataset
dataroot = "./data/land"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
#img size
image_size = 64
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
real_batch = next(iter(dataloader))
real_batch_arr= [t.numpy() for t in real_batch]
# Set number of qubits per data dimension as list of k qubit values[#q_0,...,#q_k-1]
num_qubits = [4]
k = len(num_qubits)
num_epochs = 100
# Initialize qGAN
qgan = QGAN(real_batch_arr,bounds=bounds, num_qubits = num_qubits,batch_size = 128, num_epochs=num_epochs, snapshot_dir=None)
这给了我以下错误。
ValueError Traceback (most recent call last)
<ipython-input-42-8cba9a74f024> in <module>
5
6 # Initialize qGAN
----> 7 qgan = QGAN(real_batch_arr,bounds=bounds, num_qubits = num_qubits,batch_size = 128, num_epochs=num_epochs, snapshot_dir=None)
8 qgan.seed = 1
9 # Set quantum instance to run the quantum generator
~\Anaconda3\lib\site-packages\qiskit\aqua\algorithms\distribution_learners\qgan.py in __init__(self, data, bounds, num_qubits, batch_size, num_epochs, seed, discriminator, generator, tol_rel_ent, snapshot_dir, quantum_instance)
99 if data is None:
100 raise AquaError('Training data not given.')
--> 101 self._data = np.array(data)
102 if bounds is None:
103 bounds_min = np.percentile(self._data, 5, axis=0)
ValueError: could not broadcast input array from shape (128,3,64,64) into shape (128)
我知道 qiskit 函数 (QGAN) 在某些时候正试图将 real_batch_arr 转换为一个数组(在传递给 QGAN 时定义为一个列表)。该数组预计仅为 (128),但除此之外,需要将数组传递给 QGAN,而不是列表(基于上面链接的原始代码)。
我的问题是如何将我的列表转换为我需要的数组。也可能有一些我根本就想念的东西。我真的很感激任何建议或意见。