2

我有一个,也许是小问题,但我现在被困了很长一段时间。希望有人可以帮助我。我目前正在使用我喜欢通过深度学习(CNN 网络)进行训练的 Kddcup99 数据集

我有一个包含 Panda Dataframe 的“数据集”类。因此我分成正常和验证数据集。到目前为止,没有问题。我将它加载到 Numpy 向量中,将其火炬传递到 Tensor,然后将其定向到 DataLoader。

数据集类有这两个重要的用于迭代的类:

def __len__(self):
        return len(self.val_df)

def __getitem__(self, index):        
        img, target = self.val_df[index][:-1], self.val_df[index][-1]
        return img, target, index

不在类中的是 DataLoader 字符串:

test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)

在我的 Trainer Class 中,我有一个 for 循环,它应该遍历 Dataloader:

with torch.no_grad():
            for data in dataloader:
                inputs, labels, idx = data
                inputs = inputs.to(self.device)

但它不会。我无法访问标签、索引等。

我现在的问题是:为什么? 如何通过 Dataloader 从给定的数据集中访问标签、索引?

谢谢大家的帮助!非常感谢。

4

1 回答 1

2

的第一个参数DataLoader是您要从中加载数据的数据集,通常是 a Dataset,但不限于Dataset. 只要它定义了长度(__len__)并且可以被索引(__getitem__允许)它是可以接受的。

您正在传递datat.val_dfDataLoader,这可能是一个 NumPy 数组。NumPy 数组有一个长度并且可以被索引,所以它可以在DataLoader. 由于您直接传递该数组,__getitem__因此永远不会调用您的数据集,但数组本身已被索引,因此每个项目都是data.val_df[index].

DataLoader您必须使用数据集本身(),而不是使用基础数据datat

test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
于 2020-05-18T14:37:33.470 回答