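I am trying out PyTorch Lightning by writing a very simple DataModuleClass. After writing prepare_data() and setup(), I wanted to check that these methods work correctly, so I tried to get the training and validation datasets from setup(). However, I get this error: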
AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'
Code:
def prepare_data(self):
    x = np.random.uniform(0, 10, 10)
    e = np.random.normal(0, self.sigma, len(x))
    # Making target or labels
    y = x + e
    # Merging x and e for 2 features
    X = np.transpose(np.array([x, e]))
    # Converting numpy array to Tensor
    self.x_train_tensor = torch.from_numpy(X).float().to(device)
    self.y_train_tensor = torch.from_numpy(y).float().to(device)
    training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
    self.training_dataset = training_dataset

def setup(self):
    data = self.training_dataset
    self.train_data, self.val_data = random_split(data, [8, 2])
    return self.train_data, self.val_data

def train_dataloader(self):
    return DataLoader(self.train_data)

def val_dataloader(self):
    return DataLoader(self.val_data)

obj = DataModuleClass()
print(obj.setup())
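
For reference, here is a stripped-down sketch that seems to reproduce the same error without Lightning, NumPy, or torch. It assumes that training_dataset is only ever assigned inside prepare_data() and that nothing calls prepare_data() on the object before setup() is called by hand. The class name TinyDataModule and its list-based dataset are hypothetical stand-ins for my DataModuleClass, not the real code.

```python
import random


class TinyDataModule:
    """Hypothetical stand-in for DataModuleClass (plain Python, no Lightning)."""

    def prepare_data(self):
        # The attribute is created only here, mirroring the code above.
        self.training_dataset = [
            (float(i), float(i) + random.gauss(0, 1)) for i in range(10)
        ]

    def setup(self):
        # Raises AttributeError unless prepare_data() already ran on this object.
        data = list(self.training_dataset)
        random.shuffle(data)
        self.train_data, self.val_data = data[:8], data[8:]
        return self.train_data, self.val_data


obj = TinyDataModule()

# Calling setup() directly on a fresh object, as in the question.
try:
    obj.setup()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)
print(error_message)

# Calling prepare_data() on the same object first.
obj.prepare_data()
train_data, val_data = obj.setup()
print(len(train_data), len(val_data))  # 8 2
```

In this sketch the error goes away once prepare_data() has been called on the same object before setup(), so the question may come down to which of these methods gets called automatically and when.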
Can you tell me why I am getting this error?