代码是这样的,我想用来__getattr__
让State类有用,可以在上面挂载一些属性,以便我们以后可以使用它,例如:
state = State()
state.name1 = 1
print(state.name1)
完整的代码是这样的:
class State:
"""
提供给用户使用的 state;
为了实现断点重训,用户应当保证其保存的信息都是可序列化的;
# TODO:可能需要提供保存该state的函数,但是用户可以自己在 callback 里实现,现在我们先不管;
"""
def __init__(self):
self._value = dict()
def __setattr__(self, key, value):
self._value[key] = value
def __getattr__(self, item):
if item in self._value:
return self._value[item]
else:
raise AttributeError(f"{self.__name__} has no attribute {item}.")
def state_dict(self):
return self._value
def load_state_dict(self, value: dict):
if not isinstance(value, dict):
raise ValueError("If you want to reload a state dict for reasons like breakpoint retraining, this parameter"
"value should be a dict type.")
self._value = value
回溯是这样的:
Traceback (most recent call last):
File "state.py", line 165, in <module>
a = State()
File "state.py", line 52, in __init__
self._value = dict()
File "state.py", line 55, in __setattr__
self._value[key] = value
File "state.py", line 58, in __getattr__
if item in self._value:
File "state.py", line 58, in __getattr__
if item in self._value:
File "state.py", line 58, in __getattr__
if item in self._value:
[Previous line repeated 993 more times]
RecursionError: maximum recursion depth exceeded