16

我正在使用 Airflow 1.8.1,我想从 PostgreOperator 推送 sql 请求的结果。

这是我的任务:

check_task = PostgresOperator(
    task_id='check_task',
    postgres_conn_id='conx',
    sql="check_task.sql",
    xcom_push=True,
    dag=dag)

def py_is_first_execution(**kwargs):
    value = kwargs['ti'].xcom_pull(task_ids='check_task')
    print 'count ----> ', value
    if value == 0:
       return 'next_task'
    else:
       return 'end-flow'

check_branch = BranchPythonOperator(
    task_id='is-first-execution',
    python_callable=py_is_first_execution,
    provide_context=True,
    dag=dag)

这是我的 sql 脚本:

select count(1) from table

当我从中检查 xcom 值时,check_task它会检索none值。

4

2 回答 2

10

如果我是正确的,当查询返回值时,气流会自动推送到 xcom。但是,当您查看postgresoperator的代码时,您会发现它有一个执行方法,该方法调用 PostgresHook 的 run 方法(dbapi_hook 的扩展)。这两种方法都不会返回任何东西,因此它不会向 xcom 推送任何内容。我们为解决这个问题所做的是创建一个 CustomPostgresSelectOperator,它是 PostgresOperator 的副本,但不是 'hook.run(..)' 而是执行 'return hook.get_records(..)'。

希望对您有所帮助。

于 2017-08-15T15:28:11.763 回答
2

最后,我ExecuteSqlOperator在插件管理器的$AIRFLOW_HOME/plugins.

以我CheckOperator为例,我修改了返回值:这个算子的基本运行与我所需要的正好相反。

这是默认值ExecuteSqlOperatorCheckOperator

这是我的定制SqlSensorReverseSqlSensor

class SqlExecuteOperator(BaseOperator):
    """
    Performs checks against a db. The ``CheckOperator`` expects
    a sql query that will return a single row.

    Note that this is an abstract class and get_db_hook
    needs to be defined. Whereas a get_db_hook is hook that gets a
    single record from an external source.
    :param sql: the sql to be executed
    :type sql: string
    """

    template_fields = ('sql',)
    template_ext = ('.hql', '.sql',)
    ui_color = '#fff7e6'

    @apply_defaults
    def __init__(
            self, sql,
            conn_id=None,
            *args, **kwargs):
        super(SqlExecuteOperator, self).__init__(*args, **kwargs)
        self.conn_id = conn_id
        self.sql = sql

    def execute(self, context=None):
        logging.info('Executing SQL statement: ' + self.sql)
        records = self.get_db_hook().get_first(self.sql)
        logging.info("Record: " + str(records))
        records_int = int(records[0])
        print (records_int)
        return records_int

    def get_db_hook(self):
        return BaseHook.get_hook(conn_id=self.conn_id)
于 2017-08-18T09:34:52.773 回答