我正在尝试将异步 API 与 pytest 结合使用。
首先,我尝试将 zzzeek 的 unittest 示例转换为 pytest 版本,效果很好:
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
# Declarative base class that the mapped model below derives from.
Base = declarative_base()


class Thing(Base):
    """Minimal mapped model: table ``thing`` with a single integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
    """Yield one engine for the whole test session.

    The schema described by ``Base.metadata`` is dropped and recreated
    before the engine is handed out, and dropped again after the session
    has finished.
    """
    db_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    metadata = Base.metadata

    # Start from a clean schema.
    metadata.drop_all(db_engine)
    metadata.create_all(db_engine)

    yield db_engine

    # Teardown: remove every table again.
    metadata.drop_all(db_engine)
@pytest.fixture
def session(engine_fixture):
    """Yield a Session bound to a connection with an open outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened. A listener reopens the SAVEPOINT whenever a
    top-level nested transaction ends. On teardown the session is closed,
    the outer transaction is rolled back and the connection is closed.
    """
    connection = engine_fixture.connect()
    outer_transaction = connection.begin()
    db_session = Session(bind=connection)

    # Load fixture data within the scope of the outer transaction.
    db_session.add_all([Thing() for _ in range(3)])
    db_session.commit()

    # Start the session in a SAVEPOINT...
    db_session.begin_nested()

    # ...then each time that SAVEPOINT ends, reopen it.
    def restart_savepoint(session, transaction):
        # Only react to a nested transaction whose parent is not itself nested.
        ends_top_level_savepoint = (
            transaction.nested and not transaction._parent.nested
        )
        if ends_top_level_savepoint:
            session.begin_nested()

    event.listen(db_session, "after_transaction_end", restart_savepoint)

    yield db_session

    # Teardown: discard everything the test did.
    db_session.close()
    outer_transaction.rollback()
    connection.close()
def _test_thing(session, extra_rollback=0):
    """Exercise add / commit / rollback on ``session`` and check row counts.

    Expects three ``Thing`` rows to exist on entry.

    Args:
        session: Session used for all queries and writes.
        extra_rollback: Number of add-then-rollback rounds to run both
            before and after the commit step.
    """

    def row_count():
        return len(session.query(Thing).all())

    assert row_count() == 3

    # Run N rounds of "add three rows, then roll back".
    for _round in range(extra_rollback):
        session.add_all([Thing() for _ in range(3)])
        assert row_count() == 6
        session.rollback()

    # After the rollbacks, still at 3 rows.
    assert row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    session.commit()
    assert row_count() == 5

    session.add(Thing())
    assert row_count() == 6

    # Run N more rollback rounds on top of the pending single row.
    for round_index in range(extra_rollback):
        session.add_all([Thing() for _ in range(3)])
        # From the second round on, the pending single row has already
        # been rolled back too, so one fewer row is visible.
        expected = 8 if round_index > 0 else 9
        assert row_count() == expected
        session.rollback()

    assert row_count() == (5 if extra_rollback else 6)
def test_thing_one_pytest(session):
    """Run the shared scenario with zero extra rollbacks."""
    _test_thing(session, extra_rollback=0)
def test_thing_two_pytest(session):
    """Run the shared scenario with two extra rollbacks."""
    _test_thing(session, extra_rollback=2)
然后我尝试借助 pytest-asyncio 0.14.0 版切换到 asyncio API:
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
# Declarative base class for the asyncio variant of the example.
Base = declarative_base()


class Thing(Base):
    """Mapped model for table ``thing``; its only column is an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Create the schema once per test session using a synchronous engine.

    Tables are dropped and recreated before the engine is yielded, and
    dropped again after the session has finished.
    """
    blocking_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    metadata = Base.metadata

    # Setup: start from a clean schema.
    metadata.drop_all(blocking_engine)
    metadata.create_all(blocking_engine)

    yield blocking_engine

    # Teardown: remove every table again.
    metadata.drop_all(blocking_engine)
@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    """Yield a session-scoped async engine using the asyncpg driver.

    No teardown is performed after the yield.
    """
    asyncpg_engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )
    yield asyncpg_engine
@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession bound to a connection with an open outer transaction.

    Three ``Thing`` rows are added and committed through the session, then
    a SAVEPOINT is opened. A listener on the underlying sync session
    reopens the SAVEPOINT whenever a top-level nested transaction ends. On
    teardown the session is closed, the outer transaction is rolled back
    and the connection is closed.
    """
    connection = await async_engine.connect()
    outer_transaction = await connection.begin()
    async_session = AsyncSession(bind=connection)

    # Load fixture data within the scope of the outer transaction.
    async_session.add_all([Thing() for _ in range(3)])
    await async_session.commit()

    # Start the session in a SAVEPOINT...
    await async_session.begin_nested()

    # ...then each time that SAVEPOINT ends, reopen it.
    # NOTE: no async listeners yet, so the hook goes on the sync session.
    # NOTE(review): this mirrors the synchronous recipe; presumably the
    # session wrapped by AsyncSession handles commit() differently, so the
    # SAVEPOINT may need to be managed on the connection instead - confirm
    # against the SQLAlchemy docs on joining a session into an external
    # transaction.
    def restart_savepoint(session, transaction):
        # Only react to a nested transaction whose parent is not itself nested.
        ends_top_level_savepoint = (
            transaction.nested and not transaction._parent.nested
        )
        if ends_top_level_savepoint:
            session.begin_nested()

    event.listen(
        async_session.sync_session, "after_transaction_end", restart_savepoint
    )

    yield async_session

    # Teardown: discard everything the test did.
    await async_session.close()
    await outer_transaction.rollback()
    await connection.close()
async def _test_thing(session: AsyncSession, extra_rollback=0):
    """Exercise add / commit / rollback on ``session`` and check row counts.

    Expects three ``Thing`` rows to exist on entry.

    Args:
        session: AsyncSession used for all queries and writes.
        extra_rollback: Number of add-then-rollback rounds to run both
            before and after the commit step.
    """

    async def row_count():
        result = await session.execute(select(Thing))
        return len(result.all())

    assert await row_count() == 3

    # Run N rounds of "add three rows, then roll back".
    for _round in range(extra_rollback):
        session.add_all([Thing() for _ in range(3)])
        assert await row_count() == 6
        await session.rollback()

    # After the rollbacks, still at 3 rows.
    assert await row_count() == 3

    session.add_all([Thing() for _ in range(2)])
    await session.commit()
    assert await row_count() == 5

    session.add(Thing())
    assert await row_count() == 6

    # Run N more rollback rounds on top of the pending single row.
    for round_index in range(extra_rollback):
        session.add_all([Thing() for _ in range(3)])
        # From the second round on, the pending single row has already
        # been rolled back too, so one fewer row is visible.
        expected = 8 if round_index > 0 else 9
        assert await row_count() == expected
        await session.rollback()

    assert await row_count() == (5 if extra_rollback else 6)
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Run the shared async scenario with zero extra rollbacks."""
    await _test_thing(session, extra_rollback=0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Run the shared async scenario with two extra rollbacks."""
    await _test_thing(session, extra_rollback=2)
但是,这会失败并报错 "FAILED test_thing_two_pytest - assert 8 == 3",因为第一个测试结束后,teardown 阶段的事务回滚并没有恢复到 setup 阶段创建的 SAVEPOINT。
由于我对 sqlalchemy 内部知识的了解不是很好,因此我正在寻求帮助来设置它,因为这对我的测试套件性能至关重要。
是不是因为目前缺少 async 事件侦听器,仅在 AsyncSession.sync_session 上定义 restart_savepoint 还不够,所以只能等待 1.4 API 的稳定版本?
谢谢!