summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-05-31 20:56:53 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-05-31 20:56:53 +0000
commitf192da8d70b9f3d3bf2c7d3b2ca2a8876c2bfe67 (patch)
tree9272570c91cf3dcc23bcce83c8e517f7ff60a894 /test
parent6eeee37190b49fc1f47596be7248ec1836136826 (diff)
parentd24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (diff)
downloadsqlalchemy-f192da8d70b9f3d3bf2c7d3b2ca2a8876c2bfe67.tar.gz
Merge "establish sessionmaker and async_sessionmaker as generic" into main
Diffstat (limited to 'test')
-rw-r--r--test/ext/mypy/plain_files/sessionmakers.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py
new file mode 100644
index 000000000..ce9b76638
--- /dev/null
+++ b/test/ext/mypy/plain_files/sessionmakers.py
@@ -0,0 +1,88 @@
+"""test #7656"""
+
+from sqlalchemy import create_engine
+from sqlalchemy import Engine
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import scoped_session
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
+
+
+async_engine = create_async_engine("...")
+
+
+class MyAsyncSession(AsyncSession):
+ pass
+
+
+def async_session_factory(
+ engine: AsyncEngine,
+) -> async_sessionmaker[MyAsyncSession]:
+ return async_sessionmaker(engine, class_=MyAsyncSession)
+
+
+def async_scoped_session_factory(
+ engine: AsyncEngine,
+) -> async_scoped_session[MyAsyncSession]:
+ return async_scoped_session(
+ async_sessionmaker(engine, class_=MyAsyncSession),
+ scopefunc=lambda: None,
+ )
+
+
+async def async_main() -> None:
+ fac = async_session_factory(async_engine)
+
+ async with fac() as sess:
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+ async with fac.begin() as sess:
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+ scoped_fac = async_scoped_session_factory(async_engine)
+
+ sess = scoped_fac()
+
+ # EXPECTED_TYPE: MyAsyncSession
+ reveal_type(sess)
+
+
+engine = create_engine("...")
+
+
+class MySession(Session):
+ pass
+
+
+def session_factory(
+ engine: Engine,
+) -> sessionmaker[MySession]:
+ return sessionmaker(engine, class_=MySession)
+
+
+def scoped_session_factory(engine: Engine) -> scoped_session[MySession]:
+ return scoped_session(sessionmaker(engine, class_=MySession))
+
+
+def main() -> None:
+ fac = session_factory(engine)
+
+ with fac() as sess:
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)
+
+ with fac.begin() as sess:
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)
+
+ scoped_fac = scoped_session_factory(engine)
+
+ sess = scoped_fac()
+ # EXPECTED_TYPE: MySession
+ reveal_type(sess)