将 Python 3.6 与 Pytorch 1.3.1 一起使用。我注意到当整个模块被导入另一个模块时,一些保存的 nn.Modules 无法加载。举个例子,这里是一个最小工作示例的模板。
#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'
from torch import nn
class NN(nn.Module):##NN network
# Initialisation and other class methods
networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
# Some testing snippets
pass
当我直接在 shell 中运行它时,整个文件工作得很好。但是,当我想使用该类并使用此代码将神经网络加载到另一个文件中时,它会失败。
#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *
错误读取AttributeError: Can't get attribute 'NN' on <module '__main__'>
在 Pytorch 中加载保存的变量或导入模块是否与其他常见的 Python 库不同?一些帮助或指向根本原因的指针将不胜感激。