diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-05-31 20:56:53 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-05-31 20:56:53 +0000 |
| commit | f192da8d70b9f3d3bf2c7d3b2ca2a8876c2bfe67 (patch) | |
| tree | 9272570c91cf3dcc23bcce83c8e517f7ff60a894 /lib/sqlalchemy/ext/asyncio/session.py | |
| parent | 6eeee37190b49fc1f47596be7248ec1836136826 (diff) | |
| parent | d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 (diff) | |
| download | sqlalchemy-f192da8d70b9f3d3bf2c7d3b2ca2a8876c2bfe67.tar.gz | |
Merge "establish sessionmaker and async_sessionmaker as generic" into main
Diffstat (limited to 'lib/sqlalchemy/ext/asyncio/session.py')
| -rw-r--r-- | lib/sqlalchemy/ext/asyncio/session.py | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index eac2e5806..be3414cef 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import NoReturn @@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]): ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) @@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]): """Close all :class:`_asyncio.AsyncSession` sessions.""" await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self: _AS) -> _AS: return self async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self) -> _AsyncSessionContextManager: - # TODO: can this use asynccontextmanager ?? + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: return _AsyncSessionContextManager(self) # START PROXY METHODS AsyncSession @@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]): # END PROXY METHODS AsyncSession -class async_sessionmaker: +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): """A configurable :class:`.AsyncSession` factory. The :class:`.async_sessionmaker` factory works in the same way as the @@ -1409,12 +1412,12 @@ class async_sessionmaker: """ # noqa E501 - class_: Type[AsyncSession] + class_: Type[_AS] def __init__( self, bind: Optional[_AsyncSessionBind] = None, - class_: Type[AsyncSession] = AsyncSession, + class_: Type[_AS] = AsyncSession, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -1437,7 +1440,7 @@ class async_sessionmaker: self.kw = kw self.class_ = class_ - def begin(self) -> _AsyncSessionContextManager: + def begin(self) -> _AsyncSessionContextManager[_AS]: """Produce a context manager that both provides a new :class:`_orm.AsyncSession` as well as a transaction that commits. @@ -1458,7 +1461,7 @@ class async_sessionmaker: session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> AsyncSession: + def __call__(self, **local_kw: Any) -> _AS: """Produce a new :class:`.AsyncSession` object using the configuration established in this :class:`.async_sessionmaker`. @@ -1498,16 +1501,16 @@ class async_sessionmaker: ) -class _AsyncSessionContextManager: +class _AsyncSessionContextManager(Generic[_AS]): __slots__ = ("async_session", "trans") - async_session: AsyncSession + async_session: _AS trans: AsyncSessionTransaction - def __init__(self, async_session: AsyncSession): + def __init__(self, async_session: _AS): self.async_session = async_session - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self) -> _AS: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session |
