
I created a custom dataset in my PyTorch project, and I need to add Gaussian noise to it through a transform. My dataset is a 2D array of 1s and -1s. Here is what I do:

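import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class AddGaussianNoise(object):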

    def __init__(self, mean, std):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)   
 

class Normalize(object):
 
    def __init__(self, mean, std):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        # out-of-place, so the input tensor is not modified
        return (tensor - self.mean) / self.std
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)  

class MyDataset(Dataset):

    def __init__(self, data, transforms=None):
        self.samples = data
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        sample = self.transforms(sample)
        return sample

To check the result:

data = np.array([[-1, -1, 1, -1], [-1, 1, -1, -1], [1, -1, -1, -1], [-1, -1, -1, 1]])
transformed = MyDataset(data, transforms.Compose([AddGaussianNoise(0.5, 0.5),
                                                  Normalize(0.5, 0.5)]))
print(transformed.samples)

Output:

[[-1 -1  1 -1]
 [-1  1 -1 -1]
 [ 1 -1 -1 -1]
 [-1 -1 -1  1]]

Nothing happened. But the custom transforms work fine outside the MyDataset class:

def add_noise(inputs, mean, std):
    # inputs should be a torch.Tensor: AddGaussianNoise calls tensor.size()
    transform = transforms.Compose([AddGaussianNoise(mean, std),
                                    Normalize(0.5, 0.5)])
    return transform(inputs)

tensor([[-2.0190, -2.7867,  1.8440, -1.1421],
        [-2.3795,  2.2529,  0.0627, -3.0331],
        [ 2.4760, -1.5299, -2.2118, -0.9087],
        [-1.7003,  0.1757, -1.9060,  2.0312]])

I don't understand where the problem is. Thanks.
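For reference, with mean and std both 0.5 the composed transform simplifies. AddGaussianNoise computes x + 0.5z + 0.5, with z drawn from N(0, 1), and Normalize then computes (x + 0.5z + 0.5 - 0.5) / 0.5 = 2x + z. Each output is therefore the ±1 entry doubled plus standard normal noise, which is consistent with the tensor printed above (for example -2.0190 for an entry of -1). The sketch below writes both steps inline on a torch tensor, so it does not depend on torchvision:

```python
import torch

torch.manual_seed(0)

mean, std = 0.5, 0.5
x = torch.tensor([[-1.0, -1.0, 1.0, -1.0],
                  [-1.0, 1.0, -1.0, -1.0],
                  [1.0, -1.0, -1.0, -1.0],
                  [-1.0, -1.0, -1.0, 1.0]])

z = torch.randn(x.size())        # the standard normal draw
noisy = x + z * std + mean       # what AddGaussianNoise(0.5, 0.5) computes
out = (noisy - mean) / std       # what Normalize(0.5, 0.5) computes

# The composition equals 2x + z, up to float32 rounding
print(torch.allclose(out, 2 * x + z, atol=1e-5))  # True
```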


2 Answers


@samiogx, you are not applying the transforms. transformed.samples only gives you the input, not the output. If you want the output, apply the transform:

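transform = transforms.Compose([AddGaussianNoise(0.5, 0.5), Normalize(0.5, 0.5)])
transform(torch.tensor([[-1,-1,1,-1],[-1,1,-1,-1],[1,-1,-1,-1],[-1,-1,-1,1]], dtype=torch.float32))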

That's it.
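A minimal sketch of this answer's point, assuming the question's classes. Two assumptions are not in the original: the data is stored as a float32 tensor rather than a NumPy integer array (a NumPy array's .size is an int attribute, so tensor.size() inside AddGaussianNoise would raise a TypeError), and a small hypothetical compose helper stands in for torchvision's transforms.Compose. Indexing the dataset calls __getitem__, which applies the transforms; .samples itself is never changed.

```python
import torch
from torch.utils.data import Dataset


class AddGaussianNoise:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Returns a new tensor: input plus N(mean, std^2) noise
        return tensor + torch.randn(tensor.size()) * self.std + self.mean


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Out-of-place, so the input tensor is left untouched
        return (tensor - self.mean) / self.std


def compose(*fns):
    # Hypothetical stand-in for torchvision.transforms.Compose
    def apply(x):
        for fn in fns:
            x = fn(x)
        return x
    return apply


class MyDataset(Dataset):
    def __init__(self, data, transforms=None):
        self.samples = data
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # The transforms only run here, when an item is requested
        sample = self.samples[idx]
        sample = self.transforms(sample)
        return sample


data = torch.tensor([[-1, -1, 1, -1],
                     [-1, 1, -1, -1],
                     [1, -1, -1, -1],
                     [-1, -1, -1, 1]], dtype=torch.float32)
dataset = MyDataset(data, compose(AddGaussianNoise(0.5, 0.5), Normalize(0.5, 0.5)))

print(dataset.samples)  # the raw ±1 data, unchanged
print(dataset[0])       # row 0 after noise and normalization
# Stack every transformed row into one 4x4 tensor
transformed = torch.stack([dataset[i] for i in range(len(dataset))])
print(transformed)
```

Each call to dataset[i] draws fresh noise, so the transformed values differ from one access to the next, while dataset.samples keeps the original data.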

answered 2021-11-27T07:16:49.903

Your dataset's __getitem__ method uses transforms instead of its own transform object, self.transforms.

from torch.utils.data import Dataset

class MyDataset(Dataset):

    def __init__(self, data, transforms=None):
        self.samples = data
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.transforms:
            sample = self.transforms(sample)
        return sample
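Building on this answer: if the data stays a NumPy array as in the question, one option (an assumption, not part of either answer) is to convert each row to a float32 tensor inside __getitem__ before the transforms run. A torch DataLoader can then collate the transformed rows into a batch. Only the noise transform is used here to keep the sketch short.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class AddGaussianNoise:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Returns a new tensor: input plus N(mean, std^2) noise
        return tensor + torch.randn(tensor.size()) * self.std + self.mean


class MyDataset(Dataset):
    def __init__(self, data, transforms=None):
        self.samples = data
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # NumPy int row -> float32 tensor, so tensor.size() and the noise math work
        sample = torch.as_tensor(self.samples[idx], dtype=torch.float32)
        if self.transforms:
            sample = self.transforms(sample)
        return sample


data = np.array([[-1, -1, 1, -1], [-1, 1, -1, -1], [1, -1, -1, -1], [-1, -1, -1, 1]])
dataset = MyDataset(data, AddGaussianNoise(0.5, 0.5))

# The default collate function stacks the four (4,) tensors into one (4, 4) batch
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 4])
```

With the if self.transforms guard, MyDataset(data) without any transform still works and simply returns each row as a float tensor.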

answered 2021-11-26T09:43:33.053