5

我目前正在编写一个 C++ 程序,该程序需要对 torchScript 格式的 CNN 模型的结构进行一些分析。我正在使用 C++ 火炬库,它在 torch.org 上显示的方式,像这样加载到模型中:

#include <torch/script.h>
#include <torch/torch.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
    if (argc != 2) {
        std::cerr << "usage: example-app <path-to-exported-script-module>\n";
        return -1;
    }

    torch::jit::script::Module module;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load(argv[1]);
    }
    catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    return 0;
}

据我所知,module由一组嵌套的集合组成,torch::jit::script::Module其中最低的代表内置函数。我访问那些最低的模块如下:

void print_modules(const torch::jit::script::Module& imodule) {

    for (const auto& module : imodule.named_children()) {

        if(module.value.children().size() > 0){
            print_modules(module.value);

        }
        else{

            std::cout << module.name << "\n";

        }
    }
}

该函数递归地遍历模块并打印最低级别的名称,这些名称对应于torch脚本的内置函数。

我现在的问题是,如何访问那些内置函数的详细信息,例如卷积的步长?

我一生都无法弄清楚如何访问模块的这些基本属性。

4

0 回答 0