1

我想创建一个 ZCA-whitened CIFAR-10 的自定义 PyTorch 数据集,随后我可以使用 torchvision 的 function 加载该数据集torchvision.datasets.CIFAR10()。到目前为止,我可以成功地对数据进行白化(参见下面的代码),但我不知道如何以允许使用torchvision.datasets.CIFAR10(). 我该怎么做呢?

ZCA-whiten CIFAR 10 的代码:

    trainset = torchvision.datasets.CIFAR10(
        root='./datasets',
        train=True,
        download=False)
    train_data = trainset.data.reshape(-1, 32*32*3)
    zca_matrix = zca_whitening_matrix(train_data.T)
    whitened_training_data = np.matmul(zca_matrix, train_data.T).T
    whitened_training_data = whitened_training_data.reshape((-1, 32, 32, 3))

    # whiten CIFAR-10 testing data
    testset = torchvision.datasets.CIFAR10(
        root='./datasets',
        train=False,
        download=False)
    testdata = testset.data.reshape(-1, 32*32*3)
    whitened_test_data = np.matmul(zca_matrix, testdata.T).T
    whitened_test_data = whitened_test_data.reshape((-1, 32, 32, 3))

真的只是保存 numpy 数组的最佳方法,如此处所示?

PyTorch:如何将 DataLoaders 用于自定义数据集

4

1 回答 1

2

默认情况下,这对于任何数据集都是不可能的。

推荐的方法是子类化并自己创建功能,或者使用新方法对类进行修补。

于 2019-11-15T20:59:19.050 回答