我目前正在编写一个 C++ 程序,该程序需要对 torchScript 格式的 CNN 模型的结构进行一些分析。我正在使用 C++ 火炬库,它在 torch.org 上显示的方式,像这样加载到模型中:
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
return 0;
}
据我所知,module
由一组嵌套的集合组成,torch::jit::script::Module
其中最低的代表内置函数。我访问那些最低的模块如下:
void print_modules(const torch::jit::script::Module& imodule) {
for (const auto& module : imodule.named_children()) {
if(module.value.children().size() > 0){
print_modules(module.value);
}
else{
std::cout << module.name << "\n";
}
}
}
该函数递归地遍历模块并打印最低级别的名称,这些名称对应于torch脚本的内置函数。
我现在的问题是,如何访问那些内置函数的详细信息,例如卷积的步长?
我一生都无法弄清楚如何访问模块的这些基本属性。