26

我想创建一个config dataclass以简化对特定环境变量的白名单和访问(打字os.environ['VAR_NAME']相对于 而言是乏味的config.VAR_NAME)。因此,我需要忽略我dataclass__init__函数中未使用的环境变量,但我不知道如何提取默认值__init__以便用例如也包含*_作为参数之一的函数来包装它。

import os
from dataclasses import dataclass

@dataclass
class Config:
    VAR_NAME_1: str
    VAR_NAME_2: str

config = Config(**os.environ)

运行这个给了我TypeError: __init__() got an unexpected keyword argument 'SOME_DEFAULT_ENV_VAR'

4

3 回答 3

36

在将参数列表传递给构造函数之前清理它可能是最好的方法。不过,我建议不要编写自己的__init__函数,因为 dataclass'__init__做了一些其他方便的事情,你会因为覆盖它而丢失。

此外,由于参数清理逻辑与类的行为非常紧密地绑定并返回一个实例,因此将其放入 a 中可能是有意义的classmethod

from dataclasses import dataclass
import inspect

@dataclass
class Config:
    var_1: str
    var_2: str

    @classmethod
    def from_dict(cls, env):      
        return cls(**{
            k: v for k, v in env.items() 
            if k in inspect.signature(cls).parameters
        })


# usage:
params = {'var_1': 'a', 'var_2': 'b', 'var_3': 'c'}
c = Config.from_dict(params)   # works without raising a TypeError 
print(c)
# prints: Config(var_1='a', var_2='b')
于 2019-03-11T07:25:41.250 回答
20

我只会提供一个明确的__init__而不是使用自动生成的。循环体只设置公认的值,忽略意外的值。

请注意,这不会在稍后抱怨没有默认值的缺失值。

@dataclass(init=False)
class Config:
    VAR_NAME_1: str
    VAR_NAME_2: str

    def __init__(self, **kwargs):
        names = set([f.name for f in dataclasses.fields(self)])
        for k, v in kwargs.items():
            if k in names:
                setattr(self, k, v)

或者,您可以将过滤的环境传递给 default Config.__init__

field_names = set(f.name for f in dataclasses.fields(Config))
c = Config(**{k:v for k,v in os.environ.items() if k in field_names})
于 2019-02-13T20:14:24.230 回答
4

我使用了两个答案的组合;setattr可以成为性能杀手。自然,如果字典在数据类中没有某些记录,则需要为它们设置字段默认值。

from __future__ import annotations
from dataclasses import field, fields, dataclass

@dataclass()
class Record:
    name: str
    address: str
    zip: str = field(default=None)  # won't fail if dictionary doesn't have a zip key

    @classmethod
    def create_from_dict(cls, dict_) -> Record:
        class_fields = {f.name for f in fields(cls)}
        return Record(**{k: v for k, v in dict_.items() if k in class_fields})
于 2019-07-25T18:38:17.607 回答