From d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 26 May 2022 14:35:03 -0400 Subject: establish sessionmaker and async_sessionmaker as generic This is so that custom Session and AsyncSession classes can be typed for these factories. Added appropriate typevars to `__call__()`, `__enter__()` and other methods so that a custom Session or AsyncSession subclass is carried through. Fixes: #7656 Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0 --- test/ext/mypy/plain_files/sessionmakers.py | 88 ++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 test/ext/mypy/plain_files/sessionmakers.py (limited to 'test') diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py new file mode 100644 index 000000000..ce9b76638 --- /dev/null +++ b/test/ext/mypy/plain_files/sessionmakers.py @@ -0,0 +1,88 @@ +"""test #7656""" + +from sqlalchemy import create_engine +from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker + + +async_engine = create_async_engine("...") + + +class MyAsyncSession(AsyncSession): + pass + + +def async_session_factory( + engine: AsyncEngine, +) -> async_sessionmaker[MyAsyncSession]: + return async_sessionmaker(engine, class_=MyAsyncSession) + + +def async_scoped_session_factory( + engine: AsyncEngine, +) -> async_scoped_session[MyAsyncSession]: + return async_scoped_session( + async_sessionmaker(engine, class_=MyAsyncSession), + scopefunc=lambda: None, + ) + + +async def async_main() -> None: + fac = async_session_factory(async_engine) + + async with fac() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + async with fac.begin() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + scoped_fac = async_scoped_session_factory(async_engine) + + sess = scoped_fac() + + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + +engine = create_engine("...") + + +class MySession(Session): + pass + + +def session_factory( + engine: Engine, +) -> sessionmaker[MySession]: + return sessionmaker(engine, class_=MySession) + + +def scoped_session_factory(engine: Engine) -> scoped_session[MySession]: + return scoped_session(sessionmaker(engine, class_=MySession)) + + +def main() -> None: + fac = session_factory(engine) + + with fac() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + with fac.begin() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + scoped_fac = scoped_session_factory(engine) + + sess = scoped_fac() + # EXPECTED_TYPE: MySession + reveal_type(sess) -- cgit v1.2.1