0

编辑#1:我已经按照建议使用“手动”基于 if/else 的解决方案更新了示例,以证明需要进一步自动化。


在基于关键字参数名称而不是类型选择目标函数的情况下,如何有效地调度函数(即实现类似multimethods的函数)?

我的用例是为数据类实现多个工厂方法,这些数据类的字段相互依赖,并且可以根据这些字段的不同子集进行初始化,例如

下面是一个使用递归的示例,它可以正常工作,但需要大量手写的容易出错的代码,并且不能真正扩展到更复杂的情况。

def a_from_b_c(b, c):
    """Derive field ``a`` by adding ``b`` and ``c``."""
    derived_a = b + c
    return derived_a

def b_from_a_c(a, c):
    """Derive field ``b`` by adding ``a`` and ``c``."""
    derived_b = a + c
    return derived_b

def c_from_a_b(a, b):
    """Derive field ``c`` by raising ``a`` to the power ``b``."""
    return pow(a, b)

from dataclasses import dataclass


@dataclass
class foo(object):
    """Dataclass whose three fields can each be derived from the other two."""

    a: float
    b: float
    c: float

    @classmethod
    def init_from(cls, **kwargs):
        """Build an instance from a subset of the fields given as keywords.

        A missing field is computed from the two others when both are
        present; the method then recurses with the completed keyword set.
        If a field is missing and cannot be derived, the final ``cls(...)``
        call raises ``TypeError`` for the missing argument.
        """
        # Return the recursive result directly instead of discarding it and
        # constructing a second, identical instance afterwards.
        if "a" not in kwargs and "b" in kwargs and "c" in kwargs:
            kwargs["a"] = a_from_b_c(kwargs["b"], kwargs["c"])
            return cls.init_from(**kwargs)
        if "b" not in kwargs and "a" in kwargs and "c" in kwargs:
            kwargs["b"] = b_from_a_c(kwargs["a"], kwargs["c"])
            return cls.init_from(**kwargs)
        if "c" not in kwargs and "a" in kwargs and "b" in kwargs:
            kwargs["c"] = c_from_a_b(kwargs["a"], kwargs["b"])
            return cls.init_from(**kwargs)
        return cls(**kwargs)
        

我正在寻找一种解决方案:它既能扩展到具有许多字段和复杂初始化路径的数据类,又能减少手写代码——这类代码重复很多,而且容易出错。上面代码中的模式非常明显,完全可以自动化,但我想确保在这里使用正确的工具。

4

2 回答 2

0

在最近的编辑之后,坦率地说这是一个很大的变化,你可以做这样的事情:

import inspect
from dataclasses import dataclass
from collections import defaultdict


class Initializer:
    """Registry of functions able to compute a given field.

    Calling an instance with a field name returns a decorator; the decorated
    function is appended to ``mappings[field_name]``, so several functions
    may be registered for the same field.
    """

    def __init__(self):
        # Field name -> list of registered functions, in registration order.
        self.mappings = defaultdict(list)

    def __call__(self, arg):
        """Return a decorator that registers a function for field ``arg``."""
        def wrapper(func):
            self.mappings[arg].append(func)
            # Hand the function back so the decorated name stays bound to it
            # instead of being replaced by ``None``.
            return func
        return wrapper


# Create the shared registry instance; functions are registered on it
# below via the ``@init("<field name>")`` decorator syntax.
init = Initializer()


# The catch-all `**kwargs` absorbs any extra keyword arguments for convenience
@init("a")
def a_from_b_c(b, c, **kwargs):
    """Compute ``a`` as ``b + c``; extra keyword arguments are ignored."""
    derived_a = b + c
    return derived_a


@init("a")
def a_from_b_d(b, d, **kwargs):
    """Compute ``a`` as ``b + d``; extra keyword arguments are ignored."""
    derived_a = b + d
    return derived_a


@init("b")
def b_from_a_c(a, c, **kwargs):
    """Compute ``b`` as ``a + c``; extra keyword arguments are ignored."""
    derived_b = a + c
    return derived_b


@init("c")
def c_from_a_b(a, b, **kwargs):
    """Compute ``c`` as ``a`` to the power ``b``; extra keywords are ignored."""
    return pow(a, b)


@init("d")
def d_from_a_b_c(a, b, c, **kwargs):
    """Compute ``d`` as ``a ** b + c``; extra keyword arguments are ignored."""
    power = pow(a, b)
    return power + c


@dataclass
class foo(object):
    """Dataclass whose missing fields are filled in from registered functions."""

    a: float
    b: float
    c: float
    d: float

    @classmethod
    def init_from(cls, **kwargs):
        """Build an instance from a subset of the fields given as keywords.

        Each missing field is computed with the first function registered for
        it in ``init.mappings`` whose named parameters are all available.
        The fields are rescanned until a full pass adds nothing new, so a
        field computed late can unlock one that was skipped earlier.

        Returns:
            A fully initialised instance of ``cls``.

        Raises:
            TypeError: from the dataclass constructor when a field is still
                missing because no registered function could compute it.
        """
        progressed = True
        while progressed:
            progressed = False
            for field in cls.__dataclass_fields__:
                if field in kwargs:
                    continue
                # ``.get`` avoids inserting empty lists into the registry for
                # fields that have no registered function.
                for func in init.mappings.get(field, ()):
                    # Only the named parameters matter; ``**kwargs`` on the
                    # registered function absorbs every other known field.
                    func_args = inspect.getfullargspec(func).args
                    if all(arg in kwargs for arg in func_args):
                        kwargs[field] = func(**kwargs)
                        progressed = True
                        break
        # Always construct and return an instance, including when nothing
        # was missing, instead of falling off the end and returning None.
        return cls(**kwargs)

然后使用它:

>>> foo.init_from(a=3, b=2, d=3)
foo(a=3, b=2, c=9, d=3)

>>> foo.init_from(a=3, b=2, c=3)
foo(a=3, b=2, c=3, d=12)
于 2021-01-16T19:06:02.950 回答
0

这是一个基于 @kostas-mouratidis 想法的解决方案,即存储从字段到用于初始化这些字段的方法的映射。通过使用类装饰器,映射可以与类存储在一起(恕我直言,本该如此)。再对初始化字段的方法使用另一个装饰器,生成的代码看起来非常干净、可读。

有什么改进建议吗?

from dataclasses import dataclass
import inspect 

def dataclass_greedy_init(cls):
    """Dataclass decorator that adds an 'init_from' class method to recursively initialize
    all fields and fully initialize an instance of the class from a given subset of
    fields specified as keyword arguments.

    In order to achieve this, the class is searched for *field init methods*, i.e. static
    methods decorated with the 'init_field' decorator. A mapping from field names to these
    methods is built and stored as the class attribute 'init_mapping'. The 'init_from'
    method looks up appropriate methods given the set of fields specified as keyword
    arguments. It initializes missing fields recursively in a greedy fashion, i.e. it
    initializes the first missing field for which a field init method can be found
    and all arguments to this field init method can be supplied.

    Exceptions raised by a field init method itself, or by the dataclass constructor
    when some field cannot be derived, propagate to the caller unchanged.
    """

    # Collect all field init methods (static methods appear as plain functions
    # when looked up on the class).
    init_methods = inspect.getmembers(
        cls, lambda f: inspect.isfunction(f) and hasattr(f, "init_field")
    )
    # Create a mapping from field names to signatures (i.e. required fields)
    # and field init methods.
    cls.init_mapping = {}
    for _name, init_method in init_methods:
        cls.init_mapping.setdefault(init_method.init_field, []).append(
            (inspect.signature(init_method), init_method)
        )

    # Add classmethod 'init_from'
    def init_from(cls, **kwargs):
        for field in cls.__dataclass_fields__:
            if field in kwargs:
                continue
            for init_method_sig, init_method in cls.init_mapping.get(field, ()):
                mapped_kwargs = {
                    p: kwargs[p] for p in init_method_sig.parameters if p in kwargs
                }
                # Only the binding step is guarded: a TypeError here means a
                # required argument is not available yet. Errors raised by the
                # init method or by deeper recursion must not be swallowed.
                try:
                    bound_args = init_method_sig.bind(**mapped_kwargs)
                except TypeError:
                    continue
                bound_args.apply_defaults()
                # Use args/kwargs rather than the raw ``arguments`` mapping so
                # positional-only and variadic parameters are passed correctly.
                kwargs[field] = init_method(*bound_args.args, **bound_args.kwargs)
                return cls.init_from(**kwargs)
        return cls(**kwargs)

    cls.init_from = classmethod(init_from)
    return cls

def init_field(field_name):
    """Build a decorator marking a function as the init method for one field.

    The decorator stores ``field_name`` on the function as its ``init_field``
    attribute and returns the function wrapped in ``staticmethod``. It is
    meant to be combined with the 'dataclass_greedy_init' class decorator.
    """
    def mark_as_field_init(method):
        setattr(method, "init_field", field_name)
        wrapped = staticmethod(method)
        return wrapped

    return mark_as_field_init

@dataclass_greedy_init
@dataclass
class foo(object):
    """Example dataclass whose fields can be derived from one another."""

    a: float
    b: float
    c: float
    d: float

    @init_field("a")
    def init_a_from_b_c(b, c):
        """Derive ``a`` as ``c - b``."""
        return c - b

    @init_field("b")
    def init_b_from_a_c(a, c):
        """Derive ``b`` as ``c - a``."""
        return c - a

    @init_field("c")
    def init_c_from_a_b(a, b):
        """Derive ``c`` as ``a + b``."""
        return a + b

    @init_field("c")
    def init_c_from_d(d):
        """Derive ``c`` as half of ``d``."""
        return d / 2

    @init_field("d")
    def init_d_from_a_b_c(a, b, c):
        """Derive ``d`` as ``a + b + c``."""
        return a + b + c

    @init_field("d")
    def init_d_from_a(a):
        """Derive ``d`` as six times ``a``."""
        return 6 * a

# Build instances from different subsets of known fields and print each one
for known_fields in (
    {"a": 1, "b": 2},
    {"a": 1, "c": 3},
    {"b": 2, "c": 3},
    {"a": 1},
):
    print(foo.init_from(**known_fields))

于 2021-01-17T12:37:12.777 回答