我有一个,也许是小问题,但我现在被困了很长一段时间。希望有人可以帮助我。我目前正在使用我喜欢通过深度学习(CNN 网络)进行训练的 Kddcup99 数据集
我有一个包含 Panda Dataframe 的“数据集”类。因此我分成正常和验证数据集。到目前为止,没有问题。我将它加载到 Numpy 向量中,将其火炬传递到 Tensor,然后将其定向到 DataLoader。
数据集类有这两个重要的用于迭代的类:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
不在类中的是 DataLoader 字符串:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的 Trainer Class 中,我有一个 for 循环,它应该遍历 Dataloader:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但它不会。我无法访问标签、索引等。
我现在的问题是:为什么? 如何通过 Dataloader 从给定的数据集中访问标签、索引?
谢谢大家的帮助!非常感谢。