2

我有一个包含一组图像的文件夹,我想在 LibTorch 中将它们作为 CustomDataset 加载。

// Entry point of the example: builds the training dataset from a folder of images.
int main()
{
    ...
    // Glob pattern selecting every .jpg under dataset/img_align_celeba;
    // CustomDataset passes it to obtain_data(), which expands it with cv::glob.
    std::string file_location{"dataset/img_align_celeba/*.jpg"};
    // map() wraps the dataset with the Stack<> transform.
    // NOTE(review): Stack<> presumably collates the samples of a batch into one
    // stacked tensor — confirm against the LibTorch data API documentation.
    auto train_set = CustomDataset(file_location).map(data::transforms::Stack<>());
}

CustomDataset.h

// NOTE(review): a using-directive at header scope becomes visible in every
// file that includes this header.
using namespace torch;


// Reads the image file at `loc` and returns its pixel data as a tensor.
// Defined in the accompanying .cpp file.
torch::Tensor read_data(const std::string& loc);

// Expands the pattern in `address` into the list of matching file names.
// Defined in the accompanying .cpp file (implemented with cv::glob).
std::vector<std::string> obtain_data(std::string &address);

class CustomDataset : public data::Dataset<CustomDataset> {
private:
    std::vector<std::string> m_filenames;
public:
    // Constructor
    explicit CustomDataset(std::string &files): m_filenames(obtain_data(files)) {
    };

    // Override get() function to return tensor at location index
    torch::data::Example<> get(size_t index) override {
        torch::Tensor sample_img = read_data(m_filenames[index]);
        torch::Tensor sample_label = torch::full({1}, 1);
        return {sample_img, sample_label};
    };

    // Return the length of data
    torch::optional<size_t> size() const override {
        return m_filenames.size();
    };
};

CustomDataset.cpp

#include "CustomDataset.h"

/// Loads the image at `loc` and converts it to a float tensor.
///
/// The result is laid out channel-first in R, G, B order with shape
/// {3, rows, cols}, where rows/cols are the dimensions of the decoded image
/// (so a 64x64 image yields {3, 64, 64}).
///
/// @param loc Path of the image file to read.
/// @return Float tensor owning its own copy of the pixel data.
/// @throws c10::Error (via TORCH_CHECK) if the image cannot be loaded or does
///         not decode to three channels.
torch::Tensor read_data(const std::string& loc)
{
    // cv::imread reports failure by returning an empty matrix, not by throwing.
    const cv::Mat img = cv::imread(loc);
    TORCH_CHECK(!img.empty(), "read_data: could not load image '", loc, "'");

    // Separate the interleaved pixels into one single-channel plane per colour.
    std::vector<cv::Mat> channels;
    cv::split(img, channels);
    TORCH_CHECK(channels.size() == 3,
                "read_data: expected 3 channels in '", loc, "', got ", channels.size());

    // Use the real image size: a fixed 64x64 shape would misread larger images
    // and read past the end of the buffer for smaller ones.
    const int64_t height = img.rows;
    const int64_t width = img.cols;

    // from_blob only wraps the cv::Mat memory without taking ownership.
    auto plane = [&](std::size_t idx) {
        return torch::from_blob(channels[idx].ptr(), {height, width}, torch::kUInt8);
    };

    // OpenCV stores colour images as B, G, R, so planes 2, 1, 0 give R, G, B.
    // stack() copies into a new {3, rows, cols} tensor, so the result stays
    // valid after `channels` goes out of scope.
    return torch::stack({plane(2), plane(1), plane(0)}).to(torch::kFloat);
}

// Collects the paths of all files that match the pattern in `address`.
std::vector<std::string> obtain_data(std::string &address)
{
    // cv::glob writes every path matching the pattern into the output vector.
    std::vector<std::string> matched_paths;
    cv::glob(address, matched_paths);
    return matched_paths;
}

在 CustomDataset 类内部调试时,我可以看到成员变量 m_filenames 不为空,并且 get() 也能返回包含图像和标签的张量;但一旦回到 main.cpp,类型为 CustomDataset 的变量 train_set 的大小却是 0。

4

0 回答 0