attr 的包以某种方式破坏了 pytorch 的parameter()
模块方法。我想知道是否有人有任何变通办法或解决方案,以便这两个包可以无缝集成?
如果没有,关于将问题发布到哪个 github 的任何建议?我的直觉是将其发布到 attr 的 github 上,但堆栈跟踪几乎完全与 pytorch 的代码库相关。
Python 3.7.3
attrs== 19.1.0
torch==1.1.0.post2
torchvision==0.3.0
import attr
import torch
class RegularModule(torch.nn.Module):
pass
@attr.s
class AttrsModule(torch.nn.Module):
pass
module = RegularModule()
print(list(module.parameters()))
module = AttrsModule()
print(list(module.parameters()))
实际输出为:
$python attrs_pytorch.py
[]
Traceback (most recent call last):
File "attrs_pytorch.py", line 18, in <module>
print(list(module.parameters()))
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 814, in parameters
for name, param in self.named_parameters(recurse=recurse):
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 840, in named_parameters
for elem in gen:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 784, in _named_members
for module_prefix, module in modules:
File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 975, in named_modules
if self not in memo:
TypeError: unhashable type: 'AttrsModule'
预期的输出是:
$python attrs_pytorch.py
[]
[]