1

我有一个程序(ASGI 服务器),其结构大致如下:

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifepsan():
    ctxvar.set("spam")


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    ctx = contextvars.copy_context()
    task = asyncio.create_task(lifepsan())
    await task
    task = asyncio.create_task(endpoint())
    await task

asyncio.run(main())

因为生命周期事件/端点在任务中运行,所以它们不能共享上下文变量。这是设计使然:任务在执行之前复制上下文,因此lifespan无法ctxvar正确设置。这是端点的期望行为,但我希望执行看起来像这样(从用户的角度来看):

async def lifespan():
    ctxvar.set("spam")
    await endpoint()

换句话说,端点在它们自己的独立上下文中执行,但在生命周期的上下文中。

我试图通过使用来使它工作contextlib.copy_context()

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifepsan():
    ctxvar.set("spam")
    print("set")


async def endpoint():
    print("get")
    assert ctxvar.get() == "spam"


async def main():
    ctx = contextvars.copy_context()
    task = ctx.run(asyncio.create_task, lifepsan())
    await task
    endpoint_ctx = ctx.copy()
    task = endpoint_ctx.run(asyncio.create_task, endpoint())
    await task

asyncio.run(main())

也:

async def main():
    ctx = contextvars.copy_context()
    task = asyncio.create_task(ctx.run(lifespan))
    await task
    endpoint_ctx = ctx.copy()
    task = asyncio.create_task(endpoint_ctx.run(endpoint))
    await task

但是,这种方式似乎contextvars.Context.run不起作用(我猜上下文是在创建协程时绑定的,但不是在执行时绑定的)。

有没有一种简单的方法来实现所需的行为,而无需重新构建任务的创建方式等?

4

1 回答 1

1

这是我在PEP 555asgiref的启发下想出的:

from contextvars import Context, ContextVar, copy_context
from typing import Any


def _set_cvar(cvar: ContextVar, val: Any):
    cvar.set(val)


class CaptureContext:

    def __init__(self) -> None:
        self.context = Context()

    def __enter__(self) -> "CaptureContext":
        self._outer = copy_context()
        return self

    def sync(self):
        final = copy_context()
        for cvar in final:
            if cvar not in self._outer:
                # new contextvar set
                self.context.run(_set_cvar, cvar, final.get(cvar))
            else:
                final_val = final.get(cvar)
                if self._outer.get(cvar) != final_val:
                    # value changed
                    self.context.run(_set_cvar, cvar, final_val)

    def __exit__(self, *args: Any):
        self.sync()


def restore_context(context: Context) -> None:
    """Restore `context` to the current Context"""
    for cvar in context.keys():
        try:
            cvar.set(context.get(cvar))
        except LookupError:
            cvar.set(context.get(cvar))

用法:

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifepsan(cap: CaptureContext):
    with cap:
        ctxvar.set("spam")


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    await asyncio.create_task(lifepsan(cap))
    restore_context(cap.context)
    task = asyncio.create_task(endpoint())
    await task

asyncio.run(main())

sync()如果任务长时间运行并且您需要在任务完成之前捕获上下文,则提供该方法。一个有点做作的例子:

import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifepsan(cap: CaptureContext, event: asyncio.Event):
    with cap:
        ctxvar.set("spam")
        cap.sync()
        event.set()
        await asyncio.sleep(float("inf"))


async def endpoint():
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    event = asyncio.Event()
    asyncio.create_task(lifepsan(cap, event))
    await event.wait()
    restore_context(cap.context)
    task = asyncio.create_task(endpoint())
    await task

asyncio.run(main())

contextvars.Context.run我认为如果与协程一起工作会更好。

于 2021-08-04T17:42:24.257 回答