3

我正在尝试使用异步API 和pytest.

首先,我尝试将zzzeekunittest的示例转换为,效果很好pytest

import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session")
def engine_fixture():
    engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    yield engine

    Base.metadata.drop_all(engine)


@pytest.fixture
def session(engine_fixture):
    conn = engine_fixture.connect()
    trans = conn.begin()
    session = Session(bind=conn)

    def _fixture(session):
        session.add_all([Thing(), Thing(), Thing()])
        session.commit()

    # load fixture data within the scope of the transaction
    _fixture(session)

    # start the session in a SAVEPOINT...
    session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    @event.listens_for(session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    session.close()
    trans.rollback()
    conn.close()


def _test_thing(session, extra_rollback=0):

    rows = session.query(Thing).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        assert len(rows) == 6

        session.rollback()

    # after rollbacks, still @ 3 rows
    rows = session.query(Thing).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    session.commit()

    rows = session.query(Thing).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = session.query(Thing).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        session.rollback()

    rows = session.query(Thing).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


def test_thing_one_pytest(session):
    # run zero rollbacks
    _test_thing(session, 0)


def test_thing_two_pytest(session):
    # run two extra rollbacks
    _test_thing(session, 2)

然后我尝试使用0.14.0 版切换到asyncioAPIpytest-asyncio

import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    # setup
    sync_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    Base.metadata.drop_all(sync_engine)
    Base.metadata.create_all(sync_engine)

    yield sync_engine

    # teardown
    Base.metadata.drop_all(sync_engine)


@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    # setup
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )

    yield engine


@pytest.fixture(scope="function")
async def session(async_engine):
    conn = await async_engine.connect()
    trans = await conn.begin()
    session = AsyncSession(bind=conn)

    async def _fixture(session: AsyncSession):
        session.add_all([Thing(), Thing(), Thing()])
        await session.commit()

    # load fixture data within the scope of the transaction
    await _fixture(session)

    # start the session in a SAVEPOINT...
    await session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    # NOTE: no async listeners yet
    @event.listens_for(session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    await session.close()
    await trans.rollback()
    await conn.close()


async def _test_thing(session: AsyncSession, extra_rollback=0):

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        assert len(rows) == 6

        await session.rollback()

    # after rollbacks, still @ 3 rows
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    await session.commit()

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        await session.rollback()

    rows = (await session.execute(select(Thing))).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    # run zero rollbacks
    await _test_thing(session, 0)


@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    # run two extra rollbacks
    await _test_thing(session, 2)

但是,这会失败"FAILED test_thing_two_pytest - assert 8 == 3",因为在第一次测试之后的事务回滚teardown不会恢复到setup阶段中创建的 SAVEPOINT。

由于我对 sqlalchemy 内部知识的了解不是很好,因此我正在寻求帮助来设置它,因为这对我的测试套件性能至关重要。

难道缺少async事件侦听器和定义restart_savepointAsyncSession.sync_session不够的,只需要等待 1.4 API 的稳定版本吗?

谢谢!

4

1 回答 1

2

原来是一个错误,直接联系了 SA 开发人员。

Github问题

使固定

注意:根据@zzzek的说法,API 发生了变化,应该使用connection.begin_nested()支持:session.begin_nested()

The "legacy" pattern that you have above which uses "session.begin_nested()" to create the savepoint, this is not supported for the "future" style engine which asyncio uses. The new version uses the connection itself to recreate the savepoint inside the event.

于 2021-01-07T07:24:45.877 回答