2

attr 的包以某种方式破坏了 pytorch 的parameter()模块方法。我想知道是否有人有任何变通办法或解决方案,以便这两个包可以无缝集成?

如果没有,关于将问题发布到哪个 github 的任何建议?我的直觉是将其发布到 attr 的 github 上,但堆栈跟踪几乎完全与 pytorch 的代码库相关。

Python 3.7.3
attrs== 19.1.0
torch==1.1.0.post2
torchvision==0.3.0
import attr
import torch


class RegularModule(torch.nn.Module):
    pass

@attr.s
class AttrsModule(torch.nn.Module):
    pass


module = RegularModule()
print(list(module.parameters()))

module = AttrsModule()
print(list(module.parameters()))

实际输出为:

$python attrs_pytorch.py
[]
Traceback (most recent call last):
  File "attrs_pytorch.py", line 18, in <module>
    print(list(module.parameters()))
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 814, in parameters
    for name, param in self.named_parameters(recurse=recurse):
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 840, in named_parameters
    for elem in gen:
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 784, in _named_members
    for module_prefix, module in modules:
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 975, in named_modules
    if self not in memo:
TypeError: unhashable type: 'AttrsModule'

预期的输出是:

$python attrs_pytorch.py
[]
[]
4

2 回答 2

1

您可以使用一种解决方法并使用dataclasses它(您应该使用它,因为它在标准 Python 库中,因为3.7您显然正在使用它)。虽然我认为简单__init__更具可读性。可以使用库(禁用散列)做类似的事情attrs,如果可能的话,我只是更喜欢使用标准库的解决方案。

原因(如果您设法处理与散列相关的错误)是您正在调用torch.nn.Module.__init__()生成_parameters属性和其他特定于框架的数据。

首先解决散列dataclasses

@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
    pass

这解决了关于和部分hashing所述的问题:documentationhasheq

默认情况下,dataclass() 不会隐式添加hash () 方法,除非这样做是安全的。

这是 PyTorch 需要的,因此该模型可以在 C++ 支持中使用(如果我错了,请纠正我),此外:

如果 eq 为 false,hash () 将保持不变,这意味着将使用超类的 hash () 方法(如果超类是对象,这意味着它将回退到基于 id 的散列)。

所以你可以很好地使用torch.nn.Module __hash__函数(如果出现任何进一步的错误,请参阅数据类的文档)。

这会给您留下错误:

AttributeError: 'AttrsModule' object has no attribute '_parameters'

因为torch.nn.Module没有调用构造函数。快速而肮脏的修复:

@dataclasses.dataclass(eq=False)
class AttrsModule(torch.nn.Module):
    def __post_init__(self):
        super().__init__()

__post_init__是一个 after 调用的函数__init__(谁会猜到),您可以在其中初始化特定于 Torch 的参数。

不过,我建议不要同时使用这两个模块。例如,您正在__repr__使用您的代码销毁 PyTorch,因此repr=False应该将其传递给dataclasses.dataclass构造函数,它会给出最终代码(我希望消除库之间的明显冲突):

import dataclasses

import torch


class RegularModule(torch.nn.Module):
    pass


@dataclasses.dataclass(eq=False, repr=False)
class AttrsModule(torch.nn.Module):
    def __post_init__(self):
        super().__init__()


module = RegularModule()
print(list(module.parameters()))

module = AttrsModule()
print(list(module.parameters()))

有关更多信息,attrs请参阅hynek回答和他的博客文章。

于 2019-07-31T13:46:55.777 回答
1

attrs有一章关于哈希性,也解释了 Python 中哈希的缺陷:https ://www.attrs.org/en/stable/hashing.html

您必须决定哪种行为适合您的具体问题。有关更多一般信息,请查看https://hynek.me/articles/hashes-and-equality/ — 事实证明,散列在 Python 中非常棘手。

于 2019-08-01T18:28:20.917 回答