2

我有一个dataclass我想知道每个字段是否被显式设置或者它是否由default或填充default_factory

我知道我可以使用所有字段dataclasses.fields(...),这可能适用于使用的字段,default但对于使用default_factory.

我的最终目标是合并两个数据类实例AB。而B应该只覆盖A使用默认值的A的字段。

用例是一个可以在多个位置指定的配置对象,其中一些位置的优先级高于其他位置。

编辑:一个例子

from dataclasses import dataclass, field

def bar():
  return "bar"

@dataclass
class Configuration:
  foo: str = field(default_factory=bar)

conf1 = Configuration(
)

conf2 = Configuration(
  foo="foo"
)

conf3 = Configuration(
  foo="bar"
)

我想检测到它conf1.foo正在使用默认值并且conf2.foo&conf3.foo是明确设置的。

4

1 回答 1

3

首先,merge根据您fieldsz. 但是鉴于此实现dataclass完全按照预期的方式使用工具,这意味着它相当稳定,所以如果可能的话,你会想要使用它:

from dataclasses import asdict, dataclass, field, fields, MISSING


@dataclass
class A:
    a: str
    b: float = 5
    c: list = field(default_factory=list)


def merge(base, add_on):
    retain = {}
    for f in fields(base):
        val = getattr(base, f.name)
        if val == f.default:
            continue
        if f.default_factory != MISSING:
            if val == f.default_factory():
                continue
        retain[f.name] = val
    kwargs = {**asdict(add_on), **retain}
    return type(base)(**kwargs)


fill = A('1', 1, [1])

x = A('a')
y = A('a', 2, [3])
z = A('a', 5, [])
print(merge(x, fill))  # good: A(a='a', b=1, c=[1])
print(merge(y, fill))  # good: A(a='a', b=2, c=[3])
print(merge(z, fill))  # bad:  A(a='a', b=1, c=[1])

正确处理z案例将涉及某种黑客行为,我个人只是再次装饰数据类:

from dataclasses import asdict, dataclass, field, fields


def mergeable(inst):
    old_init = inst.__init__

    def new_init(self, *args, **kwargs):
        self.__customs = {f.name for f, _ in zip(fields(self), args)}
        self.__customs |= kwargs.keys()
        old_init(self, *args, **kwargs)

    def merge(self, other):
        retain = {n: v for n, v in asdict(self).items() if n in self.__customs}
        kwargs = {**asdict(other), **retain}
        return type(self)(**kwargs)

    inst.__init__ = new_init
    inst.merge = merge
    return inst


@mergeable
@dataclass
class A:
    a: str
    b: float = 5
    c: list = field(default_factory=list)


fill = A('1', 1, [1])

x = A('a')
y = A('a', 2, [3])
z = A('a', 5, [])

print(x.merge(fill))  # good: A(a='a', b=1, c=[1])
print(y.merge(fill))  # good: A(a='a', b=2, c=[3])
print(z.merge(fill))  # good: A(a='a', b=5, c=[])

不过,这很可能会产生一些难以猜测的副作用,因此使用风险自负。

于 2019-06-04T22:58:25.947 回答