2

kedro建议将参数存储在conf/base/parameters.yml. 让我们假设它看起来像这样:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

现在想象一下,我有一些data_engineering管道,其nodes.py功能如下所示:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

我将如何着手并将嵌套参数直接传递给这个函数data_engineering/pipeline.py?我尝试失败:

from kedro.pipeline import Pipeline, node

from .nodes import split_data


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

我知道我可以通过使用将所有参数传递给函数,['parameters']或者只是传递所有model_params参数,['params:model_params']但这似乎不优雅,我觉得必须有一种方法。将不胜感激任何输入!

4

2 回答 2

2

(免责声明:我是 Kedro 团队的一员)

谢谢你的问题。不幸的是,当前版本的 Kedro 不支持嵌套参数。临时解决方案是在节点内使用顶级键(正如您已经指出的那样)或使用某种参数过滤器装饰您的节点函数,这也不是优雅的。

可能最可行的解决方案是通过覆盖方法来自定义您的ProjectContext(in ) 类,如下所示:src/<package_name>/run.py_get_feed_dict

class ProjectContext(KedroContext):
    # ...


    def _get_feed_dict(self) -> Dict[str, Any]:
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
            key = "params:{}".format(param_name)
            feed_dict[key] = param_value

            if isinstance(param_value, dict):
                for key, val in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, key), val)

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict

另请注意,此问题已在开发中得到解决,并将在下一个版本中提供。该修复程序使用上面代码段中的方法。

于 2020-04-27T09:31:30.917 回答
1

正如 Dmitry 所提到的,在节点输入中kedro 0.16.0 引入了嵌套参数值,可以通过.运算符访问:

node(func, "params:a.b", None)

在 CLI 中kedro 0.17.6 启用覆盖嵌套参数params,例如

kedro run --params="model.model_tuning.booster:gbtree"
于 2022-01-06T13:35:03.580 回答