12

我正在用 FastAPI 编写我的第一个项目,但我有点挣扎。特别是,我不确定我应该如何在我的应用程序中使用 asyncpg 连接池。目前我所拥有的是这样的

在 db.py 我有

pgpool = None


async def get_pool():
    global pgpool
    if not pgpool:
        pgpool = await asyncpg.create_pool(dsn='MYDB_DSN')
    return pgpool

然后在单个文件中,我使用 get_pool 作为依赖项。

@router.post("/user/", response_model=models.User, status_code=201)
async def create_user(user: models.UserCreate, pgpool = Depends(get_pool)):
    # ... do things ...

首先,我拥有的每个端点都使用数据库,因此为每个函数添加依赖参数似乎很愚蠢。其次,这似乎是一种迂回的做事方式。我定义一个全局,然后定义一个返回该全局的函数,然后注入该函数。我相信有更自然的方式来处理它。

我看到人们建议将我需要的任何东西作为属性添加到应用程序对象中

@app.on_event("startup")
async def startup():
    app.pool = await asyncpg.create_pool(dsn='MYDB_DSN')

但是当我有多个带有路由器的文件时它不起作用,我不知道如何从路由器对象访问应用程序对象。

我错过了什么?

4

2 回答 2

9

您可以使用应用程序工厂模式来设置您的应用程序。

为避免使用全局或直接向应用程序对象添加内容,您可以创建自己的类数据库来保存连接池。

要将连接池传递给每个路由,您可以使用中间件并将池添加到request.state

这是示例代码:

import asyncio

import asyncpg
from fastapi import FastAPI, Request

class Database():

    async def create_pool(self):
        self.pool = await asyncpg.create_pool(dsn='MYDB_DSN')

def create_app():

    app = FastAPI()
    db = Database()

    @app.middleware("http")
    async def db_session_middleware(request: Request, call_next):
        request.state.pgpool = db.pool
        response = await call_next(request)
        return response

    @app.on_event("startup")
    async def startup():
        await db.create_pool()

    @app.on_event("shutdown")
    async def shutdown():
        # cleanup
        pass

    @app.get("/")
    async def hello(request: Request):
        print(request.state.pool)

    return app

app = create_app()
于 2020-08-05T23:49:17.563 回答
0

我这样做的方式是在 db.py 中。

class Database:
    def __init__(self,user,password,host,database,port="5432"):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.database = database
        self._cursor = None

        self._connection_pool = None
        
    async def connect(self):
        if not self._connection_pool:
            try:
                self._connection_pool = await asyncpg.create_pool(
                    min_size=1,
                    max_size=20,
                    command_timeout=60,
                    host=self.host,
                    port=self.port,
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    ssl="require"
                )
                logger.info("Database pool connectionn opened")

            except Exception as e:
                logger.exception(e)

    async def fetch_rows(self, query: str,*args):
        if not self._connection_pool:
            await self.connect()
        else:
            con = await self._connection_pool.acquire()
            try:
                result = await con.fetch(query,*args)
                return result
            except Exception as e:
                logger.exception(e)
            finally:
                await self._connection_pool.release(con)

    async def close(self):
        if not self._connection_pool:
            try:
                await self._connection_pool.close()
                logger.info("Database pool connection closed")
            except Exception as e:
                logger.exception(e)

然后在应用程序中

@app.on_event("startup")
async def startup_event():
    database_instance = db.Database(**db_arguments)
    await database_instance.connect()
    app.state.db = database_instance
    logger.info("Server Startup")

@app.on_event("shutdown")
async def shutdown_event():
    if not app.state.db:
        await app.state.db.close()
    logger.info("Server Shutdown")

然后,您可以通过在路由中传入请求参数来获取带有 request.app.state.db 的数据库实例。

于 2021-01-19T07:53:14.033 回答