0

我正在尝试PyTorch Lightning通过编写一个非常简单的DataModuleClass. 之后prepare_data()setup()我正在尝试检查这些功能是否正常工作。所以,我正在尝试从中获取trainingvalidation数据集setup()。但我收到一个错误

AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'

代码

def prepare_data(self):
    x = np.random.uniform(0, 10, 10)
    e = np.random.normal(0, self.sigma, len(x))
    
    # Making target or labels
    y = x + e
    
    # Marging x and e for 2 features
    X = np.transpose(np.array([x, e]))

    # Converting numpy array to Tensor
    self.x_train_tensor = torch.from_numpy(X).float().to(device)
    self.y_train_tensor = torch.from_numpy(y).float().to(device)
    
    training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)

    self.training_dataset = training_dataset

def setup(self):
    data = self.training_dataset
    self.train_data, self.val_data = random_split(data, [8, 2])
    return self.train_data, self.val_data
    
    
def train_dataloader(self):
    return DataLoader(self.train_data)

def val_dataloader(self):
    return DataLoader(self.val_data)
    
obj = DataModuleClass()
print(obj.setup())  

你能告诉我为什么我会收到这个错误吗?

4

1 回答 1

0

从代码在我看来。

的变量self.training_dataset在第一行DataModuleClass被初始化prepare_datasetup需要它。

但是你打电话setup没有打电话training_dataset

如果prepare_data预计每次创建DataModuleClass对象时都会调用,那么最好prepare_data放入__init__. 像

def __init__(self, other_params):
    ..... all your code previously in __init__
    self.prepare_data()  # put this in the last line of this function

但如果你不需要,那么你需要先prepare_data打电话setup

obj = DataModuleClass()
obj.prepare_data()
print(obj.setup())  

或者放在prepare_data自己setup身上。

def setup(self):
    self.prepare_data()
    data = self.training_dataset
    self.train_data, self.val_data = random_split(data, [8, 2])
    return self.train_data, self.val_data

self.train_data编辑 1:查看和的实际值self.val_data

从返回的对象setuptorch.utils.data.dataset.Subset. 基本上有两种获取张量的方法。

1. 像对待列表一样对待它们

train_data, val_data = obj.setup()
print(train_data[0])

2.使用for循环

train_data, val_data = obj.setup()
for data in train_data:
    print(data)

笔记

我不确定你是否会得到张量或TensorDataset. 如果是后者,则再次使用相同的技巧,例如

train_data, val_data = obj.setup()
train_tensor_data = train_data[0]
print(train_tensor_data[0])
于 2021-05-08T14:11:17.683 回答