6

我需要用我训练的卷积神经网络的数据测试结果编写一个文件。数据包括语音数据收集。文件格式需要是“文件名,预测”,但我很难提取文件名。我像这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

我正在尝试按如下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

问题os.listdir(TESTH_DATA_PATH + "/all")[i]在于它与加载的文件顺序不同步test_loader。我能做些什么?

4

3 回答 3

7

好吧,这取决于您Dataset的实施方式。例如,在这种torchvision.datasets.MNIST(...)情况下,您不能仅仅因为没有单个样本的文件名之类的东西而检索文件名(MNIST 样本以不同的方式加载)。

由于您没有展示您的Dataset实现,我将告诉您如何使用torchvision.datasets.ImageFolder(...)(或任何torchvision.datasets.DatasetFolder(...))完成此操作:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

您可以看到文件的路径是在 期间检索的__getitem__(self, index),特别是在此处

如果您实现了自己的Dataset(并且可能希望支持shuffleand batch_size > 1),那么我会sample_fname__getitem__(...)通话中返回并执行以下操作:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

这样你就不需要关心了shuffle。如果batch_size大于 1,则需要将循环的内容更改为更通用的内容,例如:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()
于 2019-06-21T08:12:45.313 回答
1

在一般情况下DataLoader,它会为您提供其内部数据集中的批次。

AS @Barriel 在单/多标签分类问题的情况下提到,DataLoader没有图像文件名,只有代表图像的张量和类/标签。

但是,DataLoader加载对象时的构造函数可能会占用一些小东西(与数据集一起,您可以根据需要打包目标/标签和文件名),甚至是数据框

这样,DataLoader可能会以某种方式抓住您需要的东西。

于 2019-06-21T17:50:55.480 回答
0

如果您使用 PyCharm 或任何具有调试工具的 IDE,请使用它来查看您的 data_loader,希望您能看到文件名列表,就像我的情况一样。

就我而言,我的 data_loader 是由 mmsegmentation 创建的。 由 mmsegmentation 创建的 Data_loader

于 2021-12-19T13:32:47.713 回答