7

如何从 DataLoader 加载整个数据集?我只得到一批数据集。

这是我的代码

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))
4

3 回答 3

8

batch_size=dataset.__len__()如果 dataset 是 torch ,您可以设置Dataset,否则batch_szie=len(dataset)应该可以使用。

请注意,这可能需要大量内存,具体取决于您的数据集。

于 2019-08-07T04:53:44.717 回答
7

我不确定您是想在网络训练以外的其他地方使用数据集(例如检查图像),还是想在训练期间迭代批次。

遍历数据集

要么按照 Usman Ali 的回答(可能会溢出)你的记忆,要么你可以这样做

for i in range(len(dataset)): # or i, image in enumerate(dataset)
    images, labels = dataset[i] # or whatever your dataset returns

您之所以能够写作dataset[i],是因为您实现了__len__并且__getitem__在您的Dataset类中(只要它是 Pytorch类的子Dataset类)。

从数据加载器获取所有批次

我理解您的问题的方式是您想要检索所有批次来训练网络。您应该明白,这iter为您提供了数据加载器的迭代器(如果您不熟悉迭代器的概念,请参阅wikipedia entry)。next告诉迭代器给你下一个项目。

因此,与遍历列表的迭代器相比,数据加载器总是返回下一个项目。列表迭代器在某个时候停止。我假设你有一些类似的时期和每个时期的步数。然后你的代码看起来像这样

for i in range(epochs):
    # some code
    for j in range(steps_per_epoch):
        images, labels = next(iter(dataloader))
        prediction = net(images)
        loss = net.loss(prediction, labels)
        ...

小心next(iter(dataloader))。如果您想遍历一个列表,这也可能有效,因为 Python 缓存对象,但每次从索引 0 开始时,您都可能会得到一个新的迭代器。为了避免这种情况,将迭代器取出到顶部,如下所示:

iterator = iter(dataloader)
for i in range(epochs):
    for j in range(steps_per_epoch):
        images, labels = next(iterator)
于 2019-08-07T06:35:07.577 回答
4

另一种选择是直接获取整个数据集,而不使用数据加载器,如下所示:

images, labels = dataset[:]
于 2020-06-05T13:35:23.443 回答