summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-10 15:42:35 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-11 22:11:07 -0400
commita45e2284dad17fbbba3bea9d5e5304aab21c8c94 (patch)
treeac31614f2d53059570e2edffe731baf384baea23 /lib/sqlalchemy
parentaa9cd878e8249a4a758c7f968e929e92fede42a5 (diff)
downloadsqlalchemy-a45e2284dad17fbbba3bea9d5e5304aab21c8c94.tar.gz
pep-484: asyncio
in this patch the asyncio/events.py module, which existed only to raise errors when trying to attach event listeners, is removed, as we were already coding an asyncio-specific workaround in upstream Pool / Session to raise this error, just moved the error out to the target and did the same thing for Engine. We also add an async_sessionmaker class. The initial rationale here is because sessionmaker() is hardcoded to Session subclasses, and there's not a way to get the use case of sessionmaker(class_=AsyncSession) to type correctly without changing the sessionmaker() symbol itself to be a function and not a class, which gets too complicated for what this is. Additionally, _SessionClassMethods has only three methods on it, one of which is not usable with asyncio (close_all()), the others not generally used from the session class. Change-Id: I064a5fa5d91cc8d5bbe9597437536e37b4e801fe
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/base.py20
-rw-r--r--lib/sqlalchemy/engine/events.py25
-rw-r--r--lib/sqlalchemy/engine/interfaces.py4
-rw-r--r--lib/sqlalchemy/engine/result.py4
-rw-r--r--lib/sqlalchemy/event/base.py2
-rw-r--r--lib/sqlalchemy/ext/asyncio/__init__.py29
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py118
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py401
-rw-r--r--lib/sqlalchemy/ext/asyncio/events.py44
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py180
-rw-r--r--lib/sqlalchemy/ext/asyncio/scoping.py241
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py406
-rw-r--r--lib/sqlalchemy/orm/base.py8
-rw-r--r--lib/sqlalchemy/orm/events.py14
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py9
-rw-r--r--lib/sqlalchemy/orm/mapper.py4
-rw-r--r--lib/sqlalchemy/orm/scoping.py130
-rw-r--r--lib/sqlalchemy/orm/session.py26
-rw-r--r--lib/sqlalchemy/orm/state.py36
-rw-r--r--lib/sqlalchemy/pool/events.py6
-rw-r--r--lib/sqlalchemy/sql/base.py3
-rw-r--r--lib/sqlalchemy/util/_concurrency_py3k.py10
22 files changed, 1125 insertions, 595 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 8bcc7e258..594a19344 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -42,7 +42,7 @@ from ..sql import util as sql_util
_CompiledCacheType = MutableMapping[Any, "Compiled"]
if typing.TYPE_CHECKING:
- from . import Result
+ from . import CursorResult
from . import ScalarResult
from .interfaces import _AnyExecuteParams
from .interfaces import _AnyMultiExecuteParams
@@ -472,7 +472,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
else:
return self._dbapi_connection
- def get_isolation_level(self) -> str:
+ def get_isolation_level(self) -> _IsolationLevel:
"""Return the current isolation level assigned to this
:class:`_engine.Connection`.
@@ -1186,9 +1186,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> Result:
+ ) -> CursorResult:
r"""Executes a SQL statement construct and returns a
- :class:`_engine.Result`.
+ :class:`_engine.CursorResult`.
:param statement: The statement to be executed. This is always
an object that is in both the :class:`_expression.ClauseElement` and
@@ -1235,7 +1235,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
func: FunctionElement[Any],
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> Result:
+ ) -> CursorResult:
"""Execute a sql.FunctionElement object."""
return self._execute_clauseelement(
@@ -1306,7 +1306,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
ddl: DDLElement,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> Result:
+ ) -> CursorResult:
"""Execute a schema.DDL object."""
execution_options = ddl._execution_options.merge_with(
@@ -1403,7 +1403,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
elem: Executable,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
- ) -> Result:
+ ) -> CursorResult:
"""Execute a sql.ClauseElement object."""
execution_options = elem._execution_options.merge_with(
@@ -1476,7 +1476,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
compiled: Compiled,
distilled_parameters: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS,
- ) -> Result:
+ ) -> CursorResult:
"""Execute a sql.Compiled object.
TODO: why do we have this? likely deprecate or remove
@@ -1526,7 +1526,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
statement: str,
parameters: Optional[_DBAPIAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- ) -> Result:
+ ) -> CursorResult:
r"""Executes a SQL statement construct and returns a
:class:`_engine.CursorResult`.
@@ -1603,7 +1603,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
execution_options: _ExecuteOptions,
*args: Any,
**kw: Any,
- ) -> Result:
+ ) -> CursorResult:
"""Create an :class:`.ExecutionContext` and execute, returning
a :class:`_engine.CursorResult`."""
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index 699faf489..ef10946a8 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -16,6 +16,7 @@ from typing import Tuple
from typing import Type
from typing import Union
+from .base import Connection
from .base import Engine
from .interfaces import ConnectionEventsTarget
from .interfaces import DBAPIConnection
@@ -123,9 +124,23 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]):
_dispatch_target = ConnectionEventsTarget
@classmethod
- def _listen( # type: ignore[override]
+ def _accept_with(
+ cls,
+ target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]],
+ ) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]:
+ default_dispatch = super()._accept_with(target)
+ if default_dispatch is None and hasattr(
+ target, "_no_async_engine_events"
+ ):
+ target._no_async_engine_events() # type: ignore
+
+ return default_dispatch
+
+ @classmethod
+ def _listen(
cls,
event_key: event._EventKey[ConnectionEventsTarget],
+ *,
retval: bool = False,
**kw: Any,
) -> None:
@@ -769,7 +784,9 @@ class DialectEvents(event.Events[Dialect]):
def _listen( # type: ignore
cls,
event_key: event._EventKey[Dialect],
+ *,
retval: bool = False,
+ **kw: Any,
) -> None:
target = event_key.dispatch_target
@@ -789,10 +806,8 @@ class DialectEvents(event.Events[Dialect]):
return target.dialect
elif isinstance(target, Dialect):
return target
- elif hasattr(target, "dispatch") and hasattr(
- target.dispatch._events, "_no_async_engine_events"
- ):
- target.dispatch._events._no_async_engine_events()
+ elif hasattr(target, "_no_async_engine_events"):
+ target._no_async_engine_events()
else:
return None
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index aa75da614..54fe21d74 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -46,7 +46,7 @@ from ..util.typing import TypedDict
if TYPE_CHECKING:
from .base import Connection
from .base import Engine
- from .result import Result
+ from .cursor import CursorResult
from .url import URL
from ..event import _ListenerFnType
from ..event import dispatcher
@@ -2422,7 +2422,7 @@ class ExecutionContext:
def _get_cache_stats(self) -> str:
raise NotImplementedError()
- def _setup_result_proxy(self) -> Result:
+ def _setup_result_proxy(self) -> CursorResult:
raise NotImplementedError()
def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int:
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 880bd8d4c..11998e718 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -536,7 +536,9 @@ class ResultInternal(InPlaceGenerative, Generic[_R]):
return interim_rows
@HasMemoized_ro_memoized_attribute
- def _onerow_getter(self) -> Callable[..., Union[_NoRow, _R]]:
+ def _onerow_getter(
+ self,
+ ) -> Callable[..., Union[Literal[_NoRow._NO_ROW], _R]]:
make_row = self._row_getter
post_creational_filter = self._post_creational_filter
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
index 8ed4c64ba..c16f6870b 100644
--- a/lib/sqlalchemy/event/base.py
+++ b/lib/sqlalchemy/event/base.py
@@ -256,6 +256,7 @@ class _HasEventsDispatch(Generic[_ET]):
def _listen(
cls,
event_key: _EventKey[_ET],
+ *,
propagate: bool = False,
insert: bool = False,
named: bool = False,
@@ -361,6 +362,7 @@ class Events(_HasEventsDispatch[_ET]):
def _listen(
cls,
event_key: _EventKey[_ET],
+ *,
propagate: bool = False,
insert: bool = False,
named: bool = False,
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
index 15b2cb015..dfe89a154 100644
--- a/lib/sqlalchemy/ext/asyncio/__init__.py
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -5,18 +5,17 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-from .engine import async_engine_from_config
-from .engine import AsyncConnection
-from .engine import AsyncEngine
-from .engine import AsyncTransaction
-from .engine import create_async_engine
-from .events import AsyncConnectionEvents
-from .events import AsyncSessionEvents
-from .result import AsyncMappingResult
-from .result import AsyncResult
-from .result import AsyncScalarResult
-from .scoping import async_scoped_session
-from .session import async_object_session
-from .session import async_session
-from .session import AsyncSession
-from .session import AsyncSessionTransaction
+from .engine import async_engine_from_config as async_engine_from_config
+from .engine import AsyncConnection as AsyncConnection
+from .engine import AsyncEngine as AsyncEngine
+from .engine import AsyncTransaction as AsyncTransaction
+from .engine import create_async_engine as create_async_engine
+from .result import AsyncMappingResult as AsyncMappingResult
+from .result import AsyncResult as AsyncResult
+from .result import AsyncScalarResult as AsyncScalarResult
+from .scoping import async_scoped_session as async_scoped_session
+from .session import async_object_session as async_object_session
+from .session import async_session as async_session
+from .session import async_sessionmaker as async_sessionmaker
+from .session import AsyncSession as AsyncSession
+from .session import AsyncSessionTransaction as AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
index 3f77f5500..7fdd2d7e0 100644
--- a/lib/sqlalchemy/ext/asyncio/base.py
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -1,36 +1,103 @@
+# ext/asyncio/base.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
import abc
import functools
+from typing import Any
+from typing import ClassVar
+from typing import Dict
+from typing import Generic
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TypeVar
import weakref
from . import exc as async_exc
+from ... import util
+from ...util.typing import Literal
+
+_T = TypeVar("_T", bound=Any)
+
+
+_PT = TypeVar("_PT", bound=Any)
-class ReversibleProxy:
- # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
- _proxy_objects = {}
+SelfReversibleProxy = TypeVar(
+ "SelfReversibleProxy", bound="ReversibleProxy[Any]"
+)
+
+
+class ReversibleProxy(Generic[_PT]):
+ _proxy_objects: ClassVar[
+ Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
+ ] = {}
__slots__ = ("__weakref__",)
- def _assign_proxied(self, target):
+ @overload
+ def _assign_proxied(self, target: _PT) -> _PT:
+ ...
+
+ @overload
+ def _assign_proxied(self, target: None) -> None:
+ ...
+
+ def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
if target is not None:
- target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ target_ref: weakref.ref[_PT] = weakref.ref(
+ target, ReversibleProxy._target_gced
+ )
proxy_ref = weakref.ref(
self,
- functools.partial(ReversibleProxy._target_gced, target_ref),
+ functools.partial( # type: ignore
+ ReversibleProxy._target_gced, target_ref
+ ),
)
ReversibleProxy._proxy_objects[target_ref] = proxy_ref
return target
@classmethod
- def _target_gced(cls, ref, proxy_ref=None):
+ def _target_gced(
+ cls: Type[SelfReversibleProxy],
+ ref: weakref.ref[_PT],
+ proxy_ref: Optional[weakref.ref[SelfReversibleProxy]] = None,
+ ) -> None:
cls._proxy_objects.pop(ref, None)
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT
+ ) -> SelfReversibleProxy:
raise NotImplementedError()
+ @overload
@classmethod
- def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy],
+ target: _PT,
+ regenerate: Literal[True] = ...,
+ ) -> SelfReversibleProxy:
+ ...
+
+ @overload
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
+ ...
+
+ @classmethod
+ def _retrieve_proxy_for_target(
+ cls: Type[SelfReversibleProxy], target: _PT, regenerate: bool = True
+ ) -> Optional[SelfReversibleProxy]:
try:
proxy_ref = cls._proxy_objects[weakref.ref(target)]
except KeyError:
@@ -38,7 +105,7 @@ class ReversibleProxy:
else:
proxy = proxy_ref()
if proxy is not None:
- return proxy
+ return proxy # type: ignore
if regenerate:
return cls._regenerate_proxy_for_target(target)
@@ -46,43 +113,54 @@ class ReversibleProxy:
return None
+SelfStartableContext = TypeVar(
+ "SelfStartableContext", bound="StartableContext"
+)
+
+
class StartableContext(abc.ABC):
__slots__ = ()
@abc.abstractmethod
- async def start(self, is_ctxmanager=False):
- pass
+ async def start(
+ self: SelfStartableContext, is_ctxmanager: bool = False
+ ) -> Any:
+ raise NotImplementedError()
- def __await__(self):
+ def __await__(self) -> Any:
return self.start().__await__()
- async def __aenter__(self):
+ async def __aenter__(self: SelfStartableContext) -> Any:
return await self.start(is_ctxmanager=True)
@abc.abstractmethod
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
pass
- def _raise_for_not_started(self):
+ def _raise_for_not_started(self) -> NoReturn:
raise async_exc.AsyncContextNotStarted(
"%s context has not been started and object has not been awaited."
% (self.__class__.__name__)
)
-class ProxyComparable(ReversibleProxy):
+class ProxyComparable(ReversibleProxy[_PT]):
__slots__ = ()
- def __hash__(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> _PT:
+ raise NotImplementedError()
+
+ def __hash__(self) -> int:
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and self._proxied == other._proxied
)
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
return (
not isinstance(other, self.__class__)
or self._proxied != other._proxied
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
index 3b54405c1..bb51a4d22 100644
--- a/lib/sqlalchemy/ext/asyncio/engine.py
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -7,23 +7,56 @@
from __future__ import annotations
from typing import Any
+from typing import Dict
+from typing import Generator
+from typing import NoReturn
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from . import exc as async_exc
from .base import ProxyComparable
from .base import StartableContext
from .result import _ensure_sync_result
from .result import AsyncResult
+from .result import AsyncScalarResult
from ... import exc
from ... import inspection
from ... import util
+from ...engine import Connection
from ...engine import create_engine as _create_engine
+from ...engine import Engine
from ...engine.base import NestedTransaction
-from ...future import Connection
-from ...future import Engine
+from ...engine.base import Transaction
from ...util.concurrency import greenlet_spawn
-
-
-def create_async_engine(*arg, **kw):
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine.cursor import CursorResult
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _DBAPIAnyExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...engine.interfaces import _IsolationLevel
+ from ...engine.interfaces import Dialect
+ from ...engine.result import ScalarResult
+ from ...engine.url import URL
+ from ...pool import Pool
+ from ...pool import PoolProxiedConnection
+ from ...sql.base import Executable
+
+
+class _SyncConnectionCallable(Protocol):
+ def __call__(self, connection: Connection, *arg: Any, **kw: Any) -> Any:
+ ...
+
+
+def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
"""Create a new async engine instance.
Arguments passed to :func:`_asyncio.create_async_engine` are mostly
@@ -43,11 +76,13 @@ def create_async_engine(*arg, **kw):
)
kw["future"] = True
kw["_is_async"] = True
- sync_engine = _create_engine(*arg, **kw)
+ sync_engine = _create_engine(url, **kw)
return AsyncEngine(sync_engine)
-def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+def async_engine_from_config(
+ configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any
+) -> AsyncEngine:
"""Create a new AsyncEngine instance using a configuration dictionary.
This function is analogous to the :func:`_sa.engine_from_config` function
@@ -73,6 +108,14 @@ def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
class AsyncConnectable:
__slots__ = "_slots_dispatch", "__weakref__"
+ @classmethod
+ def _no_async_engine_events(cls) -> NoReturn:
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
@util.create_proxy_methods(
Connection,
@@ -87,7 +130,9 @@ class AsyncConnectable:
"default_isolation_level",
],
)
-class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+class AsyncConnection(
+ ProxyComparable[Connection], StartableContext, AsyncConnectable
+):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
@@ -115,12 +160,16 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"sync_connection",
)
- def __init__(self, async_engine, sync_connection=None):
+ def __init__(
+ self,
+ async_engine: AsyncEngine,
+ sync_connection: Optional[Connection] = None,
+ ):
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
self.sync_connection = self._assign_proxied(sync_connection)
- sync_connection: Connection
+ sync_connection: Optional[Connection]
"""Reference to the sync-style :class:`_engine.Connection` this
:class:`_asyncio.AsyncConnection` proxies requests towards.
@@ -146,12 +195,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls, target: Connection
+ ) -> AsyncConnection:
return AsyncConnection(
AsyncEngine._retrieve_proxy_for_target(target.engine), target
)
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
"""Start this :class:`_asyncio.AsyncConnection` object's context
outside of using a Python ``with:`` block.
@@ -164,7 +215,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
return self
@property
- def connection(self):
+ def connection(self) -> NoReturn:
"""Not implemented for async; call
:meth:`_asyncio.AsyncConnection.get_raw_connection`.
"""
@@ -174,7 +225,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"Use the get_raw_connection() method."
)
- async def get_raw_connection(self):
+ async def get_raw_connection(self) -> PoolProxiedConnection:
"""Return the pooled DBAPI-level connection in use by this
:class:`_asyncio.AsyncConnection`.
@@ -187,16 +238,11 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
adapts the driver connection to the DBAPI protocol.
"""
- conn = self._sync_connection()
-
- return await greenlet_spawn(getattr, conn, "connection")
- @property
- def _proxied(self):
- return self.sync_connection
+ return await greenlet_spawn(getattr, self._proxied, "connection")
@property
- def info(self):
+ def info(self) -> Dict[str, Any]:
"""Return the :attr:`_engine.Connection.info` dictionary of the
underlying :class:`_engine.Connection`.
@@ -211,24 +257,28 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- return self.sync_connection.info
+ return self._proxied.info
- def _sync_connection(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Connection:
if not self.sync_connection:
self._raise_for_not_started()
return self.sync_connection
- def begin(self):
+ def begin(self) -> AsyncTransaction:
"""Begin a transaction prior to autobegin occurring."""
- self._sync_connection()
+ assert self._proxied
return AsyncTransaction(self)
- def begin_nested(self):
+ def begin_nested(self) -> AsyncTransaction:
"""Begin a nested transaction and return a transaction handle."""
- self._sync_connection()
+ assert self._proxied
return AsyncTransaction(self, nested=True)
- async def invalidate(self, exception=None):
+ async def invalidate(
+ self, exception: Optional[BaseException] = None
+ ) -> None:
+
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
@@ -237,39 +287,27 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
- return await greenlet_spawn(conn.invalidate, exception=exception)
-
- async def get_isolation_level(self):
- conn = self._sync_connection()
- return await greenlet_spawn(conn.get_isolation_level)
-
- async def set_isolation_level(self):
- conn = self._sync_connection()
- return await greenlet_spawn(conn.get_isolation_level)
-
- def in_transaction(self):
- """Return True if a transaction is in progress.
-
- .. versionadded:: 1.4.0b2
+ return await greenlet_spawn(
+ self._proxied.invalidate, exception=exception
+ )
- """
+ async def get_isolation_level(self) -> _IsolationLevel:
+ return await greenlet_spawn(self._proxied.get_isolation_level)
- conn = self._sync_connection()
+ def in_transaction(self) -> bool:
+ """Return True if a transaction is in progress."""
- return conn.in_transaction()
+ return self._proxied.in_transaction()
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
"""Return True if a transaction is in progress.
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
-
- return conn.in_nested_transaction()
+ return self._proxied.in_nested_transaction()
- def get_transaction(self):
+ def get_transaction(self) -> Optional[AsyncTransaction]:
"""Return an :class:`.AsyncTransaction` representing the current
transaction, if any.
@@ -281,15 +319,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
- trans = conn.get_transaction()
+ trans = self._proxied.get_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[AsyncTransaction]:
"""Return an :class:`.AsyncTransaction` representing the current
nested (savepoint) transaction, if any.
@@ -301,15 +338,14 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. versionadded:: 1.4.0b2
"""
- conn = self._sync_connection()
- trans = conn.get_nested_transaction()
+ trans = self._proxied.get_nested_transaction()
if trans is not None:
return AsyncTransaction._retrieve_proxy_for_target(trans)
else:
return None
- async def execution_options(self, **opt):
+ async def execution_options(self, **opt: Any) -> AsyncConnection:
r"""Set non-SQL options for the connection which take effect
during execution.
@@ -321,12 +357,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
+ conn = self._proxied
c2 = await greenlet_spawn(conn.execution_options, **opt)
assert c2 is conn
return self
- async def commit(self):
+ async def commit(self) -> None:
"""Commit the transaction that is currently in progress.
This method commits the current transaction if one has been started.
@@ -338,10 +374,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:meth:`_future.Connection.begin` method is called.
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.commit)
+ await greenlet_spawn(self._proxied.commit)
- async def rollback(self):
+ async def rollback(self) -> None:
"""Roll back the transaction that is currently in progress.
This method rolls back the current transaction if one has been started.
@@ -355,34 +390,30 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.rollback)
+ await greenlet_spawn(self._proxied.rollback)
- async def close(self):
+ async def close(self) -> None:
"""Close this :class:`_asyncio.AsyncConnection`.
This has the effect of also rolling back the transaction if one
is in place.
"""
- conn = self._sync_connection()
- await greenlet_spawn(conn.close)
+ await greenlet_spawn(self._proxied.close)
async def exec_driver_sql(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: str,
+ parameters: Optional[_DBAPIAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult:
r"""Executes a driver-level SQL string and return buffered
:class:`_engine.Result`.
"""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.exec_driver_sql,
+ self._proxied.exec_driver_sql,
statement,
parameters,
execution_options,
@@ -393,17 +424,15 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def stream(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncResult:
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.execute,
+ self._proxied.execute,
statement,
parameters,
util.EMPTY_DICT.merge_with(
@@ -418,10 +447,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def execute(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> CursorResult:
r"""Executes a SQL statement construct and return a buffered
:class:`_engine.Result`.
@@ -453,10 +482,8 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:return: a :class:`_engine.Result` object.
"""
- conn = self._sync_connection()
-
result = await greenlet_spawn(
- conn.execute,
+ self._proxied.execute,
statement,
parameters,
execution_options,
@@ -466,10 +493,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def scalar(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Any:
r"""Executes a SQL statement construct and returns a scalar object.
This method is shorthand for invoking the
@@ -485,10 +512,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def scalars(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> ScalarResult[Any]:
r"""Executes a SQL statement construct and returns a scalar objects.
This method is shorthand for invoking the
@@ -505,10 +532,10 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def stream_scalars(
self,
- statement,
- parameters=None,
- execution_options=util.EMPTY_DICT,
- ):
+ statement: Executable,
+ parameters: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> AsyncScalarResult[Any]:
r"""Executes a SQL statement and returns a streaming scalar result
object.
@@ -524,7 +551,9 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
result = await self.stream(statement, parameters, execution_options)
return result.scalars()
- async def run_sync(self, fn, *arg, **kw):
+ async def run_sync(
+ self, fn: _SyncConnectionCallable, *arg: Any, **kw: Any
+ ) -> Any:
"""Invoke the given sync callable passing self as the first argument.
This method maintains the asyncio event loop all the way through
@@ -548,14 +577,12 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
:ref:`session_run_sync`
"""
- conn = self._sync_connection()
-
- return await greenlet_spawn(fn, conn, *arg, **kw)
+ return await greenlet_spawn(fn, self._proxied, *arg, **kw)
- def __await__(self):
+ def __await__(self) -> Generator[Any, None, AsyncConnection]:
return self.start().__await__()
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
# START PROXY METHODS AsyncConnection
@@ -661,7 +688,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
],
attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
)
-class AsyncEngine(ProxyComparable, AsyncConnectable):
+class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
@@ -679,51 +706,60 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
# current transaction, info, etc. It should be possible to
# create a new AsyncEngine that matches this one given only the
# "sync" elements.
- __slots__ = ("sync_engine", "_proxied")
+ __slots__ = "sync_engine"
- _connection_cls = AsyncConnection
+ _connection_cls: Type[AsyncConnection] = AsyncConnection
- _option_cls: type
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
class _trans_ctx(StartableContext):
- def __init__(self, conn):
+ __slots__ = ("conn", "transaction")
+
+ conn: AsyncConnection
+ transaction: AsyncTransaction
+
+ def __init__(self, conn: AsyncConnection):
self.conn = conn
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncConnection:
await self.conn.start(is_ctxmanager=is_ctxmanager)
self.transaction = self.conn.begin()
await self.transaction.__aenter__()
return self.conn
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(
+ self, type_: Any, value: Any, traceback: Any
+ ) -> None:
await self.transaction.__aexit__(type_, value, traceback)
await self.conn.close()
- def __init__(self, sync_engine):
+ def __init__(self, sync_engine: Engine):
if not sync_engine.dialect.is_async:
raise exc.InvalidRequestError(
"The asyncio extension requires an async driver to be used. "
f"The loaded {sync_engine.dialect.driver!r} is not async."
)
- self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
-
- sync_engine: Engine
- """Reference to the sync-style :class:`_engine.Engine` this
- :class:`_asyncio.AsyncEngine` proxies requests towards.
+ self.sync_engine = self._assign_proxied(sync_engine)
- This instance can be used as an event target.
-
- .. seealso::
-
- :ref:`asyncio_events`
- """
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Engine:
+ return self.sync_engine
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine:
return AsyncEngine(target)
- def begin(self):
+ def begin(self) -> AsyncEngine._trans_ctx:
"""Return a context manager which when entered will deliver an
:class:`_asyncio.AsyncConnection` with an
:class:`_asyncio.AsyncTransaction` established.
@@ -741,7 +777,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
conn = self.connect()
return self._trans_ctx(conn)
- def connect(self):
+ def connect(self) -> AsyncConnection:
"""Return an :class:`_asyncio.AsyncConnection` object.
The :class:`_asyncio.AsyncConnection` will procure a database
@@ -759,7 +795,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
return self._connection_cls(self)
- async def raw_connection(self):
+ async def raw_connection(self) -> PoolProxiedConnection:
"""Return a "raw" DBAPI connection from the connection pool.
.. seealso::
@@ -769,7 +805,7 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
- def execution_options(self, **opt):
+ def execution_options(self, **opt: Any) -> AsyncEngine:
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
:class:`_asyncio.AsyncConnection` objects with the given execution
options.
@@ -781,21 +817,31 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
return AsyncEngine(self.sync_engine.execution_options(**opt))
- async def dispose(self):
+ async def dispose(self, close: bool = True) -> None:
+
"""Dispose of the connection pool used by this
:class:`_asyncio.AsyncEngine`.
- This will close all connection pool connections that are
- **currently checked in**. See the documentation for the underlying
- :meth:`_future.Engine.dispose` method for further notes.
+ :param close: if left at its default of ``True``, has the
+ effect of fully closing all **currently checked in**
+ database connections. Connections that are still checked out
+ will **not** be closed, however they will no longer be associated
+ with this :class:`_engine.Engine`,
+ so when they are closed individually, eventually the
+ :class:`_pool.Pool` which they are associated with will
+ be garbage collected and they will be closed out fully, if
+ not already closed on checkin.
+
+ If set to ``False``, the previous connection pool is de-referenced,
+ and otherwise not touched in any way.
.. seealso::
- :meth:`_future.Engine.dispose`
+ :meth:`_engine.Engine.dispose`
"""
- return await greenlet_spawn(self.sync_engine.dispose)
+ return await greenlet_spawn(self.sync_engine.dispose, close=close)
# START PROXY METHODS AsyncEngine
@@ -973,18 +1019,24 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
# END PROXY METHODS AsyncEngine
-class AsyncTransaction(ProxyComparable, StartableContext):
+class AsyncTransaction(ProxyComparable[Transaction], StartableContext):
"""An asyncio proxy for a :class:`_engine.Transaction`."""
__slots__ = ("connection", "sync_transaction", "nested")
- def __init__(self, connection, nested=False):
- self.connection = connection # AsyncConnection
- self.sync_transaction = None # sqlalchemy.engine.Transaction
+ sync_transaction: Optional[Transaction]
+ connection: AsyncConnection
+ nested: bool
+
+ def __init__(self, connection: AsyncConnection, nested: bool = False):
+ self.connection = connection
+ self.sync_transaction = None
self.nested = nested
@classmethod
- def _regenerate_proxy_for_target(cls, target):
+ def _regenerate_proxy_for_target(
+ cls, target: Transaction
+ ) -> AsyncTransaction:
sync_connection = target.connection
sync_transaction = target
nested = isinstance(target, NestedTransaction)
@@ -1000,25 +1052,22 @@ class AsyncTransaction(ProxyComparable, StartableContext):
obj.nested = nested
return obj
- def _sync_transaction(self):
+ @util.ro_non_memoized_property
+ def _proxied(self) -> Transaction:
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
@property
- def _proxied(self):
- return self.sync_transaction
+ def is_valid(self) -> bool:
+ return self._proxied.is_valid
@property
- def is_valid(self):
- return self._sync_transaction().is_valid
+ def is_active(self) -> bool:
+ return self._proxied.is_active
- @property
- def is_active(self):
- return self._sync_transaction().is_active
-
- async def close(self):
- """Close this :class:`.Transaction`.
+ async def close(self) -> None:
+ """Close this :class:`.AsyncTransaction`.
If this transaction is the base transaction in a begin/commit
nesting, the transaction will rollback(). Otherwise, the
@@ -1028,18 +1077,18 @@ class AsyncTransaction(ProxyComparable, StartableContext):
an enclosing transaction.
"""
- await greenlet_spawn(self._sync_transaction().close)
+ await greenlet_spawn(self._proxied.close)
- async def rollback(self):
- """Roll back this :class:`.Transaction`."""
- await greenlet_spawn(self._sync_transaction().rollback)
+ async def rollback(self) -> None:
+ """Roll back this :class:`.AsyncTransaction`."""
+ await greenlet_spawn(self._proxied.rollback)
- async def commit(self):
- """Commit this :class:`.Transaction`."""
+ async def commit(self) -> None:
+ """Commit this :class:`.AsyncTransaction`."""
- await greenlet_spawn(self._sync_transaction().commit)
+ await greenlet_spawn(self._proxied.commit)
- async def start(self, is_ctxmanager=False):
+ async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction:
"""Start this :class:`_asyncio.AsyncTransaction` object's context
outside of using a Python ``with:`` block.
@@ -1047,24 +1096,36 @@ class AsyncTransaction(ProxyComparable, StartableContext):
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
- self.connection._sync_connection().begin_nested
+ self.connection._proxied.begin_nested
if self.nested
- else self.connection._sync_connection().begin
+ else self.connection._proxied.begin
)
)
if is_ctxmanager:
self.sync_transaction.__enter__()
return self
- async def __aexit__(self, type_, value, traceback):
- await greenlet_spawn(
- self._sync_transaction().__exit__, type_, value, traceback
- )
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
+ await greenlet_spawn(self._proxied.__exit__, type_, value, traceback)
+
+
+@overload
+def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine:
+ ...
+
+
+@overload
+def _get_sync_engine_or_connection(
+ async_engine: AsyncConnection,
+) -> Connection:
+ ...
-def _get_sync_engine_or_connection(async_engine):
+def _get_sync_engine_or_connection(
+ async_engine: Union[AsyncEngine, AsyncConnection]
+) -> Union[Engine, Connection]:
if isinstance(async_engine, AsyncConnection):
- return async_engine.sync_connection
+ return async_engine._proxied
try:
return async_engine.sync_engine
@@ -1075,7 +1136,7 @@ def _get_sync_engine_or_connection(async_engine):
@inspection._inspects(AsyncConnection)
-def _no_insp_for_async_conn_yet(subject):
+def _no_insp_for_async_conn_yet(subject: AsyncConnection) -> NoReturn:
raise exc.NoInspectionAvailable(
"Inspection on an AsyncConnection is currently not supported. "
"Please use ``run_sync`` to pass a callable where it's possible "
@@ -1085,7 +1146,7 @@ def _no_insp_for_async_conn_yet(subject):
@inspection._inspects(AsyncEngine)
-def _no_insp_for_async_engine_xyet(subject):
+def _no_insp_for_async_engine_xyet(subject: AsyncEngine) -> NoReturn:
raise exc.NoInspectionAvailable(
"Inspection on an AsyncEngine is currently not supported. "
"Please obtain a connection then use ``conn.run_sync`` to pass a "
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
deleted file mode 100644
index c5d5e0126..000000000
--- a/lib/sqlalchemy/ext/asyncio/events.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# ext/asyncio/events.py
-# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
-# <see AUTHORS file>
-#
-# This module is part of SQLAlchemy and is released under
-# the MIT License: https://www.opensource.org/licenses/mit-license.php
-
-from .engine import AsyncConnectable
-from .session import AsyncSession
-from ...engine import events as engine_event
-from ...orm import events as orm_event
-
-
-class AsyncConnectionEvents(engine_event.ConnectionEvents):
- _target_class_doc = "SomeEngine"
- _dispatch_target = AsyncConnectable
-
- @classmethod
- def _no_async_engine_events(cls):
- raise NotImplementedError(
- "asynchronous events are not implemented at this time. Apply "
- "synchronous listeners to the AsyncEngine.sync_engine or "
- "AsyncConnection.sync_connection attributes."
- )
-
- @classmethod
- def _listen(cls, event_key, retval=False):
- cls._no_async_engine_events()
-
-
-class AsyncSessionEvents(orm_event.SessionEvents):
- _target_class_doc = "SomeSession"
- _dispatch_target = AsyncSession
-
- @classmethod
- def _no_async_engine_events(cls):
- raise NotImplementedError(
- "asynchronous events are not implemented at this time. Apply "
- "synchronous listeners to the AsyncSession.sync_session."
- )
-
- @classmethod
- def _listen(cls, event_key, retval=False):
- cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 39718735c..a9db822a6 100644
--- a/lib/sqlalchemy/ext/asyncio/result.py
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -4,25 +4,49 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
import operator
+from typing import Any
+from typing import AsyncIterator
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
from . import exc as async_exc
from ...engine.result import _NO_ROW
+from ...engine.result import _R
from ...engine.result import FilterResult
from ...engine.result import FrozenResult
from ...engine.result import MergedResult
+from ...engine.result import ResultMetaData
+from ...engine.row import Row
+from ...engine.row import RowMapping
from ...util.concurrency import greenlet_spawn
+if TYPE_CHECKING:
+ from ...engine import CursorResult
+ from ...engine import Result
+ from ...engine.result import _KeyIndexType
+ from ...engine.result import _UniqueFilterType
+ from ...engine.result import RMKeyView
-class AsyncCommon(FilterResult):
- async def close(self):
+
+class AsyncCommon(FilterResult[_R]):
+ _real_result: Result
+ _metadata: ResultMetaData
+
+ async def close(self) -> None:
"""Close this result."""
await greenlet_spawn(self._real_result.close)
-class AsyncResult(AsyncCommon):
+SelfAsyncResult = TypeVar("SelfAsyncResult", bound="AsyncResult")
+
+
+class AsyncResult(AsyncCommon[Row]):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -43,7 +67,7 @@ class AsyncResult(AsyncCommon):
"""
- def __init__(self, real_result):
+ def __init__(self, real_result: Result):
self._real_result = real_result
self._metadata = real_result._metadata
@@ -56,14 +80,16 @@ class AsyncResult(AsyncCommon):
"_row_getter", real_result.__dict__["_row_getter"]
)
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return the :meth:`_engine.Result.keys` collection from the
underlying :class:`_engine.Result`.
"""
return self._metadata.keys
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncResult, strategy: Optional[_UniqueFilterType] = None
+ ) -> SelfAsyncResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncResult`.
@@ -75,7 +101,9 @@ class AsyncResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- def columns(self, *col_expressions):
+ def columns(
+ self: SelfAsyncResult, *col_expressions: _KeyIndexType
+ ) -> SelfAsyncResult:
r"""Establish the columns that should be returned in each row.
Refer to :meth:`_engine.Result.columns` in the synchronous
@@ -85,7 +113,9 @@ class AsyncResult(AsyncCommon):
"""
return self._column_slices(col_expressions)
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[Row]]:
"""Iterate through sub-lists of rows of the size given.
An async iterator is returned::
@@ -111,7 +141,7 @@ class AsyncResult(AsyncCommon):
else:
break
- async def fetchone(self):
+ async def fetchone(self) -> Optional[Row]:
"""Fetch one row.
When all rows are exhausted, returns None.
@@ -131,9 +161,9 @@ class AsyncResult(AsyncCommon):
if row is _NO_ROW:
return None
else:
- return row # type: ignore[return-value]
+ return row
- async def fetchmany(self, size=None):
+ async def fetchmany(self, size: Optional[int] = None) -> List[Row]:
"""Fetch many rows.
When all rows are exhausted, returns an empty list.
@@ -152,11 +182,9 @@ class AsyncResult(AsyncCommon):
"""
- return await greenlet_spawn(
- self._manyrow_getter, self, size # type: ignore
- )
+ return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
+ async def all(self) -> List[Row]:
"""Return all rows in a list.
Closes the result set after invocation. Subsequent invocations
@@ -166,19 +194,19 @@ class AsyncResult(AsyncCommon):
"""
- return await greenlet_spawn(self._allrows) # type: ignore
+ return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncResult:
return self
- async def __anext__(self):
+ async def __anext__(self) -> Row:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[Row]:
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
@@ -201,7 +229,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[Row]:
"""Return at most one result or raise an exception.
Returns ``None`` if the result has no rows.
@@ -223,7 +251,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def scalar_one(self):
+ async def scalar_one(self) -> Any:
"""Return exactly one scalar result or raise an exception.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -238,7 +266,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, True, True)
- async def scalar_one_or_none(self):
+ async def scalar_one_or_none(self) -> Optional[Any]:
"""Return exactly one or no scalar result.
This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
@@ -253,7 +281,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, True)
- async def one(self):
+ async def one(self) -> Row:
"""Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
@@ -284,7 +312,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, True, False)
- async def scalar(self):
+ async def scalar(self) -> Any:
"""Fetch the first column of the first row, and close the result set.
Returns None if there are no rows to fetch.
@@ -300,7 +328,7 @@ class AsyncResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, True)
- async def freeze(self):
+ async def freeze(self) -> FrozenResult:
"""Return a callable object that will produce copies of this
:class:`_asyncio.AsyncResult` when invoked.
@@ -323,7 +351,7 @@ class AsyncResult(AsyncCommon):
return await greenlet_spawn(FrozenResult, self)
- def merge(self, *others):
+ def merge(self, *others: AsyncResult) -> MergedResult:
"""Merge this :class:`_asyncio.AsyncResult` with other compatible result
objects.
@@ -337,9 +365,12 @@ class AsyncResult(AsyncCommon):
undefined.
"""
- return MergedResult(self._metadata, (self,) + others)
+ return MergedResult(
+ self._metadata,
+ (self._real_result,) + tuple(o._real_result for o in others),
+ )
- def scalars(self, index=0):
+ def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]:
"""Return an :class:`_asyncio.AsyncScalarResult` filtering object which
will return single elements rather than :class:`_row.Row` objects.
@@ -355,7 +386,7 @@ class AsyncResult(AsyncCommon):
"""
return AsyncScalarResult(self._real_result, index)
- def mappings(self):
+ def mappings(self) -> AsyncMappingResult:
"""Apply a mappings filter to returned rows, returning an instance of
:class:`_asyncio.AsyncMappingResult`.
@@ -373,7 +404,12 @@ class AsyncResult(AsyncCommon):
return AsyncMappingResult(self._real_result)
-class AsyncScalarResult(AsyncCommon):
+SelfAsyncScalarResult = TypeVar(
+ "SelfAsyncScalarResult", bound="AsyncScalarResult[Any]"
+)
+
+
+class AsyncScalarResult(AsyncCommon[_R]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
@@ -389,7 +425,7 @@ class AsyncScalarResult(AsyncCommon):
_generate_rows = False
- def __init__(self, real_result, index):
+ def __init__(self, real_result: Result, index: _KeyIndexType):
self._real_result = real_result
if real_result._source_supports_scalars:
@@ -401,7 +437,10 @@ class AsyncScalarResult(AsyncCommon):
self._unique_filter_state = real_result._unique_filter_state
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncScalarResult,
+ strategy: Optional[_UniqueFilterType] = None,
+ ) -> SelfAsyncScalarResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncScalarResult`.
@@ -411,7 +450,9 @@ class AsyncScalarResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[_R]]:
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -429,12 +470,12 @@ class AsyncScalarResult(AsyncCommon):
else:
break
- async def fetchall(self):
+ async def fetchall(self) -> List[_R]:
"""A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchmany(self, size=None):
+ async def fetchmany(self, size: Optional[int] = None) -> List[_R]:
"""Fetch many objects.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
@@ -444,7 +485,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
+ async def all(self) -> List[_R]:
"""Return all scalar values in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
@@ -454,17 +495,17 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncScalarResult[_R]:
return self
- async def __anext__(self):
+ async def __anext__(self) -> _R:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[_R]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -474,7 +515,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[_R]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -484,7 +525,7 @@ class AsyncScalarResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def one(self):
+ async def one(self) -> _R:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -495,7 +536,12 @@ class AsyncScalarResult(AsyncCommon):
return await greenlet_spawn(self._only_one_row, True, True, False)
-class AsyncMappingResult(AsyncCommon):
+SelfAsyncMappingResult = TypeVar(
+ "SelfAsyncMappingResult", bound="AsyncMappingResult"
+)
+
+
+class AsyncMappingResult(AsyncCommon[RowMapping]):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
rather than :class:`_engine.Row` values.
@@ -513,14 +559,14 @@ class AsyncMappingResult(AsyncCommon):
_post_creational_filter = operator.attrgetter("_mapping")
- def __init__(self, result):
+ def __init__(self, result: Result):
self._real_result = result
self._unique_filter_state = result._unique_filter_state
self._metadata = result._metadata
if result._source_supports_scalars:
self._metadata = self._metadata._reduce([0])
- def keys(self):
+ def keys(self) -> RMKeyView:
"""Return an iterable view which yields the string keys that would
be represented by each :class:`.Row`.
@@ -535,7 +581,10 @@ class AsyncMappingResult(AsyncCommon):
"""
return self._metadata.keys
- def unique(self, strategy=None):
+ def unique(
+ self: SelfAsyncMappingResult,
+ strategy: Optional[_UniqueFilterType] = None,
+ ) -> SelfAsyncMappingResult:
"""Apply unique filtering to the objects returned by this
:class:`_asyncio.AsyncMappingResult`.
@@ -545,11 +594,16 @@ class AsyncMappingResult(AsyncCommon):
self._unique_filter_state = (set(), strategy)
return self
- def columns(self, *col_expressions):
+ def columns(
+ self: SelfAsyncMappingResult, *col_expressions: _KeyIndexType
+ ) -> SelfAsyncMappingResult:
r"""Establish the columns that should be returned in each row."""
return self._column_slices(col_expressions)
- async def partitions(self, size=None):
+ async def partitions(
+ self, size: Optional[int] = None
+ ) -> AsyncIterator[List[RowMapping]]:
+
"""Iterate through sub-lists of elements of the size given.
Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
@@ -567,12 +621,12 @@ class AsyncMappingResult(AsyncCommon):
else:
break
- async def fetchall(self):
+ async def fetchall(self) -> List[RowMapping]:
"""A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
return await greenlet_spawn(self._allrows)
- async def fetchone(self):
+ async def fetchone(self) -> Optional[RowMapping]:
"""Fetch one object.
Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
@@ -587,8 +641,8 @@ class AsyncMappingResult(AsyncCommon):
else:
return row
- async def fetchmany(self, size=None):
- """Fetch many objects.
+ async def fetchmany(self, size: Optional[int] = None) -> List[RowMapping]:
+ """Fetch many rows.
Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
:class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -598,8 +652,8 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._manyrow_getter, self, size)
- async def all(self):
- """Return all scalar values in a list.
+ async def all(self) -> List[RowMapping]:
+ """Return all rows in a list.
Equivalent to :meth:`_asyncio.AsyncResult.all` except that
:class:`_result.RowMapping` values, rather than :class:`_result.Row`
@@ -609,17 +663,17 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._allrows)
- def __aiter__(self):
+ def __aiter__(self) -> AsyncMappingResult:
return self
- async def __anext__(self):
+ async def __anext__(self) -> RowMapping:
row = await greenlet_spawn(self._onerow_getter, self)
if row is _NO_ROW:
raise StopAsyncIteration()
else:
return row
- async def first(self):
+ async def first(self) -> Optional[RowMapping]:
"""Fetch the first object or None if no object is present.
Equivalent to :meth:`_asyncio.AsyncResult.first` except that
@@ -630,7 +684,7 @@ class AsyncMappingResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, False, False, False)
- async def one_or_none(self):
+ async def one_or_none(self) -> Optional[RowMapping]:
"""Return at most one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
@@ -640,7 +694,7 @@ class AsyncMappingResult(AsyncCommon):
"""
return await greenlet_spawn(self._only_one_row, True, False, False)
- async def one(self):
+ async def one(self) -> RowMapping:
"""Return exactly one object or raise an exception.
Equivalent to :meth:`_asyncio.AsyncResult.one` except that
@@ -651,11 +705,15 @@ class AsyncMappingResult(AsyncCommon):
return await greenlet_spawn(self._only_one_row, True, True, False)
-async def _ensure_sync_result(result, calling_method):
+_RT = TypeVar("_RT", bound="Result")
+
+
+async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
+ cursor_result: CursorResult
if not result._is_cursor:
- cursor_result = getattr(result, "raw", None)
+ cursor_result = getattr(result, "raw", None) # type: ignore
else:
- cursor_result = result
+ cursor_result = result # type: ignore
if cursor_result and cursor_result.context._is_server_side:
await greenlet_spawn(cursor_result.close)
raise async_exc.AsyncMethodRequired(
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
index 0503076aa..0d6ae92b4 100644
--- a/lib/sqlalchemy/ext/asyncio/scoping.py
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -8,12 +8,50 @@
from __future__ import annotations
from typing import Any
-
+from typing import Callable
+from typing import Iterable
+from typing import Iterator
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
+from .session import async_sessionmaker
from .session import AsyncSession
+from ... import exc as sa_exc
from ... import util
-from ...orm.scoping import ScopedSessionMixin
+from ...orm.session import Session
from ...util import create_proxy_methods
from ...util import ScopedRegistry
+from ...util import warn
+from ...util import warn_deprecated
+
+if TYPE_CHECKING:
+ from .engine import AsyncConnection
+ from .result import AsyncResult
+ from .result import AsyncScalarResult
+ from .session import AsyncSessionTransaction
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine import Result
+ from ...engine import Row
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...engine.result import ScalarResult
+ from ...orm._typing import _IdentityKeyType
+ from ...orm._typing import _O
+ from ...orm.interfaces import ORMOption
+ from ...orm.session import _BindArguments
+ from ...orm.session import _EntityBindKey
+ from ...orm.session import _PKIdentityArgument
+ from ...orm.session import _SessionBind
+ from ...sql.base import Executable
+ from ...sql.elements import ClauseElement
+ from ...sql.selectable import ForUpdateArg
@create_proxy_methods(
@@ -62,7 +100,7 @@ from ...util import ScopedRegistry
"info",
],
)
-class async_scoped_session(ScopedSessionMixin):
+class async_scoped_session:
"""Provides scoped management of :class:`.AsyncSession` objects.
See the section :ref:`asyncio_scoped_session` for usage details.
@@ -74,17 +112,23 @@ class async_scoped_session(ScopedSessionMixin):
_support_async = True
- def __init__(self, session_factory, scopefunc):
+ session_factory: async_sessionmaker
+ """The `session_factory` provided to `__init__` is stored in this
+ attribute and may be accessed at a later time. This can be useful when
+ a new non-scoped :class:`.AsyncSession` is needed."""
+
+ registry: ScopedRegistry[AsyncSession]
+
+ def __init__(
+ self,
+ session_factory: async_sessionmaker,
+ scopefunc: Callable[[], Any],
+ ):
"""Construct a new :class:`_asyncio.async_scoped_session`.
:param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
instances. This is usually, but not necessarily, an instance
- of :class:`_orm.sessionmaker` which itself was passed the
- :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
- parameter::
-
- async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession)
- AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+ of :class:`_asyncio.async_sessionmaker`.
:param scopefunc: function which defines
the current scope. A function such as ``asyncio.current_task``
@@ -96,10 +140,59 @@ class async_scoped_session(ScopedSessionMixin):
self.registry = ScopedRegistry(session_factory, scopefunc)
@property
- def _proxied(self):
+ def _proxied(self) -> AsyncSession:
return self.registry()
- async def remove(self):
+ def __call__(self, **kw: Any) -> AsyncSession:
+ r"""Return the current :class:`.AsyncSession`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.AsyncSession` is not present. If the
+ :class:`.AsyncSession` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ else:
+ sess = self.registry()
+ if not self._support_async and sess._is_asyncio:
+ warn_deprecated(
+ "Using `scoped_session` with asyncio is deprecated and "
+ "will raise an error in a future version. "
+ "Please use `async_scoped_session` instead.",
+ "1.4.23",
+ )
+ return sess
+
+ def configure(self, **kwargs: Any) -> None:
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
+
+ async def remove(self) -> None:
"""Dispose of the current :class:`.AsyncSession`, if present.
Different from scoped_session's remove method, this method would use
@@ -152,7 +245,9 @@ class async_scoped_session(ScopedSessionMixin):
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
- """
+
+
+ """ # noqa: E501
return self._proxied.__iter__()
@@ -199,7 +294,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.add_all(instances)
- def begin(self):
+ def begin(self) -> AsyncSessionTransaction:
r"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
.. container:: class_bases
@@ -228,7 +323,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.begin()
- def begin_nested(self):
+ def begin_nested(self) -> AsyncSessionTransaction:
r"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
@@ -247,7 +342,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.begin_nested()
- async def close(self):
+ async def close(self) -> None:
r"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
@@ -284,7 +379,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.close()
- async def commit(self):
+ async def commit(self) -> None:
r"""Commit the current transaction in progress.
.. container:: class_bases
@@ -296,7 +391,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.commit()
- async def connection(self, **kw):
+ async def connection(self, **kw: Any) -> AsyncConnection:
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
@@ -321,7 +416,7 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.connection(**kw)
- async def delete(self, instance):
+ async def delete(self, instance: object) -> None:
r"""Mark an instance as deleted.
.. container:: class_bases
@@ -345,12 +440,12 @@ class async_scoped_session(ScopedSessionMixin):
async def execute(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Result:
r"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -519,7 +614,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.expunge_all()
- async def flush(self, objects=None):
+ async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
r"""Flush all the object changes to the database.
.. container:: class_bases
@@ -538,13 +633,15 @@ class async_scoped_session(ScopedSessionMixin):
async def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
r"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -568,9 +665,16 @@ class async_scoped_session(ScopedSessionMixin):
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
+ execution_options=execution_options,
)
- def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ def get_bind(
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ **kw: Any,
+ ) -> Union[Engine, Connection]:
r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
@@ -724,7 +828,7 @@ class async_scoped_session(ScopedSessionMixin):
instance, include_collections=include_collections
)
- async def invalidate(self):
+ async def invalidate(self) -> None:
r"""Close this Session, using connection invalidation.
.. container:: class_bases
@@ -738,7 +842,13 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.invalidate()
- async def merge(self, instance, load=True, options=None):
+ async def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
r"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
@@ -757,8 +867,11 @@ class async_scoped_session(ScopedSessionMixin):
return await self._proxied.merge(instance, load=load, options=options)
async def refresh(
- self, instance, attribute_names=None, with_for_update=None
- ):
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
r"""Expire and refresh the attributes on the given instance.
.. container:: class_bases
@@ -785,7 +898,7 @@ class async_scoped_session(ScopedSessionMixin):
with_for_update=with_for_update,
)
- async def rollback(self):
+ async def rollback(self) -> None:
r"""Rollback the current transaction in progress.
.. container:: class_bases
@@ -799,12 +912,12 @@ class async_scoped_session(ScopedSessionMixin):
async def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
r"""Execute a statement and return a scalar result.
.. container:: class_bases
@@ -829,12 +942,12 @@ class async_scoped_session(ScopedSessionMixin):
async def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
r"""Execute a statement and return scalar results.
.. container:: class_bases
@@ -865,12 +978,12 @@ class async_scoped_session(ScopedSessionMixin):
async def stream(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult:
r"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -892,12 +1005,12 @@ class async_scoped_session(ScopedSessionMixin):
async def stream_scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
r"""Execute a statement and return a stream of scalar results.
.. container:: class_bases
@@ -1159,7 +1272,7 @@ class async_scoped_session(ScopedSessionMixin):
return self._proxied.info
@classmethod
- async def close_all(self):
+ async def close_all(self) -> None:
r"""Close all :class:`_asyncio.AsyncSession` sessions.
.. container:: class_bases
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
index 769fe05bd..7d63b084c 100644
--- a/lib/sqlalchemy/ext/asyncio/session.py
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -7,17 +7,65 @@
from __future__ import annotations
from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Iterator
+from typing import NoReturn
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from . import engine
-from . import result as _result
from .base import ReversibleProxy
from .base import StartableContext
from .result import _ensure_sync_result
+from .result import AsyncResult
+from .result import AsyncScalarResult
from ... import util
from ...orm import object_session
from ...orm import Session
+from ...orm import SessionTransaction
from ...orm import state as _instance_state
from ...util.concurrency import greenlet_spawn
+from ...util.typing import Protocol
+
+if TYPE_CHECKING:
+ from .engine import AsyncConnection
+ from .engine import AsyncEngine
+ from ...engine import Connection
+ from ...engine import Engine
+ from ...engine import Result
+ from ...engine import Row
+ from ...engine import ScalarResult
+ from ...engine import Transaction
+ from ...engine.interfaces import _CoreAnyExecuteParams
+ from ...engine.interfaces import _CoreSingleExecuteParams
+ from ...engine.interfaces import _ExecuteOptions
+ from ...engine.interfaces import _ExecuteOptionsParameter
+ from ...event import dispatcher
+ from ...orm._typing import _IdentityKeyType
+ from ...orm._typing import _O
+ from ...orm.identity import IdentityMap
+ from ...orm.interfaces import ORMOption
+ from ...orm.session import _BindArguments
+ from ...orm.session import _EntityBindKey
+ from ...orm.session import _PKIdentityArgument
+ from ...orm.session import _SessionBind
+ from ...orm.session import _SessionBindKey
+ from ...sql.base import Executable
+ from ...sql.elements import ClauseElement
+ from ...sql.selectable import ForUpdateArg
+
+_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
+
+
+class _SyncSessionCallable(Protocol):
+ def __call__(self, session: Session, *arg: Any, **kw: Any) -> Any:
+ ...
+
_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
@@ -52,7 +100,7 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
"info",
],
)
-class AsyncSession(ReversibleProxy):
+class AsyncSession(ReversibleProxy[Session]):
"""Asyncio version of :class:`_orm.Session`.
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
@@ -69,9 +117,15 @@ class AsyncSession(ReversibleProxy):
_is_asyncio = True
- dispatch = None
+ dispatch: dispatcher[Session]
- def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None,
+ sync_session_class: Optional[Type[Session]] = None,
+ **kw: Any,
+ ):
r"""Construct a new :class:`_asyncio.AsyncSession`.
All parameters other than ``sync_session_class`` are passed to the
@@ -90,14 +144,15 @@ class AsyncSession(ReversibleProxy):
.. versionadded:: 1.4.24
"""
- kw["future"] = True
+ sync_bind = sync_binds = None
+
if bind:
self.bind = bind
- bind = engine._get_sync_engine_or_connection(bind)
+ sync_bind = engine._get_sync_engine_or_connection(bind)
if binds:
self.binds = binds
- binds = {
+ sync_binds = {
key: engine._get_sync_engine_or_connection(b)
for key, b in binds.items()
}
@@ -106,10 +161,10 @@ class AsyncSession(ReversibleProxy):
self.sync_session_class = sync_session_class
self.sync_session = self._proxied = self._assign_proxied(
- self.sync_session_class(bind=bind, binds=binds, **kw)
+ self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw)
)
- sync_session_class = Session
+ sync_session_class: Type[Session] = Session
"""The class or callable that provides the
underlying :class:`_orm.Session` instance for a particular
:class:`_asyncio.AsyncSession`.
@@ -138,9 +193,19 @@ class AsyncSession(ReversibleProxy):
"""
+ @classmethod
+ def _no_async_engine_events(cls) -> NoReturn:
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
async def refresh(
- self, instance, attribute_names=None, with_for_update=None
- ):
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
"""Expire and refresh the attributes on the given instance.
A query will be issued to the database and all attributes will be
@@ -155,14 +220,16 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+ await greenlet_spawn(
self.sync_session.refresh,
instance,
attribute_names=attribute_names,
with_for_update=with_for_update,
)
- async def run_sync(self, fn, *arg, **kw):
+ async def run_sync(
+ self, fn: _SyncSessionCallable, *arg: Any, **kw: Any
+ ) -> Any:
"""Invoke the given sync callable passing sync self as the first
argument.
@@ -191,12 +258,12 @@ class AsyncSession(ReversibleProxy):
async def execute(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Result:
"""Execute a statement and return a buffered
:class:`_engine.Result` object.
@@ -225,12 +292,12 @@ class AsyncSession(ReversibleProxy):
async def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
"""Execute a statement and return a scalar result.
.. seealso::
@@ -250,12 +317,12 @@ class AsyncSession(ReversibleProxy):
async def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
"""Execute a statement and return scalar results.
:return: a :class:`_result.ScalarResult` object
@@ -281,13 +348,16 @@ class AsyncSession(ReversibleProxy):
async def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
+
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -297,7 +367,8 @@ class AsyncSession(ReversibleProxy):
"""
- return await greenlet_spawn(
+
+ result_obj = await greenlet_spawn(
self.sync_session.get,
entity,
ident,
@@ -306,15 +377,17 @@ class AsyncSession(ReversibleProxy):
with_for_update=with_for_update,
identity_token=identity_token,
)
+ return result_obj
async def stream(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncResult:
+
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
@@ -335,16 +408,16 @@ class AsyncSession(ReversibleProxy):
bind_arguments=bind_arguments,
**kw,
)
- return _result.AsyncResult(result)
+ return AsyncResult(result)
async def stream_scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> AsyncScalarResult[Any]:
"""Execute a statement and return a stream of scalar results.
:return: an :class:`_asyncio.AsyncScalarResult` object
@@ -368,7 +441,7 @@ class AsyncSession(ReversibleProxy):
)
return result.scalars()
- async def delete(self, instance):
+ async def delete(self, instance: object) -> None:
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
@@ -381,9 +454,15 @@ class AsyncSession(ReversibleProxy):
:meth:`_orm.Session.delete` - main documentation for delete
"""
- return await greenlet_spawn(self.sync_session.delete, instance)
+ await greenlet_spawn(self.sync_session.delete, instance)
- async def merge(self, instance, load=True, options=None):
+ async def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
"""Copy the state of a given instance into a corresponding instance
within this :class:`_asyncio.AsyncSession`.
@@ -396,7 +475,7 @@ class AsyncSession(ReversibleProxy):
self.sync_session.merge, instance, load=load, options=options
)
- async def flush(self, objects=None):
+ async def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
"""Flush all the object changes to the database.
.. seealso::
@@ -406,7 +485,7 @@ class AsyncSession(ReversibleProxy):
"""
await greenlet_spawn(self.sync_session.flush, objects=objects)
- def get_transaction(self):
+ def get_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current root transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -421,7 +500,7 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]:
"""Return the current nested transaction in progress, if any.
:return: an :class:`_asyncio.AsyncSessionTransaction` object, or
@@ -437,7 +516,13 @@ class AsyncSession(ReversibleProxy):
else:
return None
- def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ def get_bind(
+ self,
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ **kw: Any,
+ ) -> Union[Engine, Connection]:
"""Return a "bind" to which the synchronous proxied :class:`_orm.Session`
is bound.
@@ -515,7 +600,7 @@ class AsyncSession(ReversibleProxy):
mapper=mapper, clause=clause, bind=bind, **kw
)
- async def connection(self, **kw):
+ async def connection(self, **kw: Any) -> AsyncConnection:
r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
this :class:`.Session` object's transactional state.
@@ -539,7 +624,7 @@ class AsyncSession(ReversibleProxy):
sync_connection
)
- def begin(self):
+ def begin(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
The underlying :class:`_orm.Session` will perform the
@@ -562,7 +647,7 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self)
- def begin_nested(self):
+ def begin_nested(self) -> AsyncSessionTransaction:
"""Return an :class:`_asyncio.AsyncSessionTransaction` object
which will begin a "nested" transaction, e.g. SAVEPOINT.
@@ -575,15 +660,15 @@ class AsyncSession(ReversibleProxy):
return AsyncSessionTransaction(self, nested=True)
- async def rollback(self):
+ async def rollback(self) -> None:
"""Rollback the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.rollback)
+ await greenlet_spawn(self.sync_session.rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit the current transaction in progress."""
- return await greenlet_spawn(self.sync_session.commit)
+ await greenlet_spawn(self.sync_session.commit)
- async def close(self):
+ async def close(self) -> None:
"""Close out the transactional resources and ORM objects used by this
:class:`_asyncio.AsyncSession`.
@@ -613,25 +698,25 @@ class AsyncSession(ReversibleProxy):
"""
return await greenlet_spawn(self.sync_session.close)
- async def invalidate(self):
+ async def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
For a complete description, see :meth:`_orm.Session.invalidate`.
"""
- return await greenlet_spawn(self.sync_session.invalidate)
+ await greenlet_spawn(self.sync_session.invalidate)
@classmethod
- async def close_all(self):
+ async def close_all(self) -> None:
"""Close all :class:`_asyncio.AsyncSession` sessions."""
- return await greenlet_spawn(self.sync_session.close_all)
+ await greenlet_spawn(self.sync_session.close_all)
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.close()
- def _maker_context_manager(self):
+ def _maker_context_manager(self) -> _AsyncSessionContextManager:
# TODO: can this use asynccontextmanager ??
return _AsyncSessionContextManager(self)
@@ -1142,21 +1227,159 @@ class AsyncSession(ReversibleProxy):
# END PROXY METHODS AsyncSession
+class async_sessionmaker:
+ """A configurable :class:`.AsyncSession` factory.
+
+ The :class:`.async_sessionmaker` factory works in the same way as the
+ :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession`
+ objects when called, creating them given
+ the configurational arguments established here.
+
+ e.g.::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.ext.asyncio import async_sessionmaker
+
+ async def main():
+ # an AsyncEngine, which the AsyncSession will use for connection
+ # resources
+ engine = create_async_engine('postgresql+asycncpg://scott:tiger@localhost/')
+
+ AsyncSession = async_sessionmaker(engine)
+
+ async with async_session() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ await session.commit()
+
+ .. versionadded:: 2.0 :class:`.asyncio_sessionmaker` provides a
+ :class:`.sessionmaker` class that's dedicated to the
+ :class:`.AsyncSession` object, including pep-484 typing support.
+
+ .. seealso::
+
+ :ref:`asyncio_orm` - shows example use
+
+ :class:`.sessionmaker` - general overview of the
+ :class:`.sessionmaker` architecture
+
+
+ :ref:`session_getting` - introductory text on creating
+ sessions using :class:`.sessionmaker`.
+
+ """ # noqa E501
+
+ class_: Type[AsyncSession]
+
+ def __init__(
+ self,
+ bind: Optional[_AsyncSessionBind] = None,
+ class_: Type[AsyncSession] = AsyncSession,
+ autoflush: bool = True,
+ expire_on_commit: bool = True,
+ info: Optional[Dict[Any, Any]] = None,
+ **kw: Any,
+ ):
+ r"""Construct a new :class:`.async_sessionmaker`.
+
+ All arguments here except for ``class_`` correspond to arguments
+ accepted by :class:`.Session` directly. See the
+ :meth:`.AsyncSession.__init__` docstring for more details on
+ parameters.
+
+
+ """
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["expire_on_commit"] = expire_on_commit
+ if info is not None:
+ kw["info"] = info
+ self.kw = kw
+ self.class_ = class_
+
+ def begin(self) -> _AsyncSessionContextManager:
+ """Produce a context manager that both provides a new
+ :class:`_orm.AsyncSession` as well as a transaction that commits.
+
+
+ e.g.::
+
+ async def main():
+ Session = async_sessionmaker(some_engine)
+
+ async with Session.begin() as session:
+ session.add(some_object)
+
+ # commits transaction, closes session
+
+
+ """
+
+ session = self()
+ return session._maker_context_manager()
+
+ def __call__(self, **local_kw: Any) -> AsyncSession:
+ """Produce a new :class:`.AsyncSession` object using the configuration
+ established in this :class:`.async_sessionmaker`.
+
+ In Python, the ``__call__`` method is invoked on an object when
+ it is "called" in the same way as a function::
+
+ AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False)
+ session = AsyncSession() # invokes sessionmaker.__call__()
+
+ """ # noqa E501
+ for k, v in self.kw.items():
+ if k == "info" and "info" in local_kw:
+ d = v.copy()
+ d.update(local_kw["info"])
+ local_kw["info"] = d
+ else:
+ local_kw.setdefault(k, v)
+ return self.class_(**local_kw)
+
+ def configure(self, **new_kw: Any) -> None:
+ """(Re)configure the arguments for this async_sessionmaker.
+
+ e.g.::
+
+ AsyncSession = async_sessionmaker(some_engine)
+
+ AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://'))
+ """ # noqa E501
+
+ self.kw.update(new_kw)
+
+ def __repr__(self) -> str:
+ return "%s(class_=%r, %s)" % (
+ self.__class__.__name__,
+ self.class_.__name__,
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+ )
+
+
class _AsyncSessionContextManager:
- def __init__(self, async_session):
+ __slots__ = ("async_session", "trans")
+
+ async_session: AsyncSession
+ trans: AsyncSessionTransaction
+
+ def __init__(self, async_session: AsyncSession):
self.async_session = async_session
- async def __aenter__(self):
+ async def __aenter__(self) -> AsyncSession:
self.trans = self.async_session.begin()
await self.trans.__aenter__()
return self.async_session
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await self.trans.__aexit__(type_, value, traceback)
await self.async_session.__aexit__(type_, value, traceback)
-class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+class AsyncSessionTransaction(
+ ReversibleProxy[SessionTransaction], StartableContext
+):
"""A wrapper for the ORM :class:`_orm.SessionTransaction` object.
This object is provided so that a transaction-holding object
@@ -1174,36 +1397,41 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
__slots__ = ("session", "sync_transaction", "nested")
- def __init__(self, session, nested=False):
+ session: AsyncSession
+ sync_transaction: Optional[SessionTransaction]
+
+ def __init__(self, session: AsyncSession, nested: bool = False):
self.session = session
self.nested = nested
self.sync_transaction = None
@property
- def is_active(self):
+ def is_active(self) -> bool:
return (
self._sync_transaction() is not None
and self._sync_transaction().is_active
)
- def _sync_transaction(self):
+ def _sync_transaction(self) -> SessionTransaction:
if not self.sync_transaction:
self._raise_for_not_started()
return self.sync_transaction
- async def rollback(self):
+ async def rollback(self) -> None:
"""Roll back this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().rollback)
- async def commit(self):
+ async def commit(self) -> None:
"""Commit this :class:`_asyncio.AsyncTransaction`."""
await greenlet_spawn(self._sync_transaction().commit)
- async def start(self, is_ctxmanager=False):
+ async def start(
+ self, is_ctxmanager: bool = False
+ ) -> AsyncSessionTransaction:
self.sync_transaction = self._assign_proxied(
await greenlet_spawn(
- self.session.sync_session.begin_nested
+ self.session.sync_session.begin_nested # type: ignore
if self.nested
else self.session.sync_session.begin
)
@@ -1212,13 +1440,13 @@ class AsyncSessionTransaction(ReversibleProxy, StartableContext):
self.sync_transaction.__enter__()
return self
- async def __aexit__(self, type_, value, traceback):
+ async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
await greenlet_spawn(
self._sync_transaction().__exit__, type_, value, traceback
)
-def async_object_session(instance):
+def async_object_session(instance: object) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` to which the given instance
belongs.
@@ -1247,7 +1475,7 @@ def async_object_session(instance):
return None
-def async_session(session: Session) -> AsyncSession:
+def async_session(session: Session) -> Optional[AsyncSession]:
"""Return the :class:`_asyncio.AsyncSession` which is proxying the given
:class:`_orm.Session` object, if any.
@@ -1260,4 +1488,4 @@ def async_session(session: Session) -> AsyncSession:
return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
-_instance_state._async_provider = async_session
+_instance_state._async_provider = async_session # type: ignore
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index d8f57e149..c5348c237 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -448,7 +448,13 @@ def _entity_descriptor(entity, key):
) from err
-_state_mapper = util.dottedgetter("manager.mapper")
+if TYPE_CHECKING:
+
+ def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]:
+ ...
+
+else:
+ _state_mapper = util.dottedgetter("manager.mapper")
@inspection._inspects(type)
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index e62a83397..c531e7cf1 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -10,6 +10,7 @@
"""
from __future__ import annotations
+from typing import Any
import weakref
from . import instrumentation
@@ -1324,7 +1325,7 @@ class _MapperEventsHold(_EventsHold):
_sessionevents_lifecycle_event_names = set()
-class SessionEvents(event.Events):
+class SessionEvents(event.Events[Session]):
"""Define events specific to :class:`.Session` lifecycle.
e.g.::
@@ -1396,12 +1397,21 @@ class SessionEvents(event.Events):
return target
elif isinstance(target, Session):
return target
+ elif hasattr(target, "_no_async_engine_events"):
+ target._no_async_engine_events()
else:
# allows alternate SessionEvents-like-classes to be consulted
return event.Events._accept_with(target)
@classmethod
- def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
+ def _listen(
+ cls,
+ event_key: Any,
+ *,
+ raw: bool = False,
+ restore_load_context: bool = False,
+ **kw: Any,
+ ) -> None:
is_instance_event = (
event_key.identifier in _sessionevents_lifecycle_event_names
)
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 030d1595b..a5dc305d2 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -35,6 +35,7 @@ from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Generic
+from typing import Optional
from typing import Set
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -44,6 +45,7 @@ from . import collections
from . import exc
from . import interfaces
from . import state
+from ._typing import _O
from .. import util
from ..event import EventTarget
from ..util import HasMemoized
@@ -52,6 +54,7 @@ from ..util.typing import Protocol
if TYPE_CHECKING:
from .attributes import InstrumentedAttribute
from .mapper import Mapper
+ from .state import InstanceState
from ..event import dispatcher
_T = TypeVar("_T", bound=Any)
@@ -71,7 +74,7 @@ class _ExpiredAttributeLoaderProto(Protocol):
class ClassManager(
HasMemoized,
Dict[str, "InstrumentedAttribute[Any]"],
- Generic[_T],
+ Generic[_O],
EventTarget,
):
"""Tracks state information at the class level."""
@@ -230,7 +233,7 @@ class ClassManager(
return frozenset([attr.impl for attr in self.values()])
@util.memoized_property
- def mapper(self) -> Mapper[_T]:
+ def mapper(self) -> Mapper[_O]:
# raises unless self.mapper has been assigned
raise exc.UnmappedClassError(self.class_)
@@ -442,7 +445,7 @@ class ClassManager(
# InstanceState management
- def new_instance(self, state=None):
+ def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O:
instance = self.class_.__new__(self.class_)
if state is None:
state = self._state_constructor(instance, self)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index c85861a59..abe11cc68 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -132,6 +132,8 @@ class Mapper(
_identity_class: Type[_O]
always_refresh: bool
+ allow_partial_pks: bool
+ version_id_col: Optional[ColumnElement[Any]]
@util.deprecated_params(
non_primary=(
@@ -2931,7 +2933,7 @@ class Mapper(
self,
state: InstanceState[_O],
dict_: _InstanceDict,
- column: Column[Any],
+ column: ColumnElement[Any],
passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
) -> Any:
prop = self._columntoproperty[column]
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
index e498b17b4..1dd7a6952 100644
--- a/lib/sqlalchemy/orm/scoping.py
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
from .interfaces import ORMOption
from .mapper import Mapper
from .query import Query
+ from .session import _BindArguments
from .session import _EntityBindKey
from .session import _PKIdentityArgument
from .session import _SessionBind
@@ -65,65 +66,7 @@ class _QueryDescriptorType(Protocol):
_O = TypeVar("_O", bound=object)
-__all__ = ["scoped_session", "ScopedSessionMixin"]
-
-
-class ScopedSessionMixin:
- session_factory: sessionmaker
- _support_async: bool
- registry: ScopedRegistry[Session]
-
- @property
- def _proxied(self) -> Session:
- return self.registry() # type: ignore
-
- def __call__(self, **kw: Any) -> Session:
- r"""Return the current :class:`.Session`, creating it
- using the :attr:`.scoped_session.session_factory` if not present.
-
- :param \**kw: Keyword arguments will be passed to the
- :attr:`.scoped_session.session_factory` callable, if an existing
- :class:`.Session` is not present. If the :class:`.Session` is present
- and keyword arguments have been passed,
- :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
-
- """
- if kw:
- if self.registry.has():
- raise sa_exc.InvalidRequestError(
- "Scoped session is already present; "
- "no new arguments may be specified."
- )
- else:
- sess = self.session_factory(**kw)
- self.registry.set(sess)
- else:
- sess = self.registry()
- if not self._support_async and sess._is_asyncio:
- warn_deprecated(
- "Using `scoped_session` with asyncio is deprecated and "
- "will raise an error in a future version. "
- "Please use `async_scoped_session` instead.",
- "1.4.23",
- )
- return sess
-
- def configure(self, **kwargs: Any) -> None:
- """reconfigure the :class:`.sessionmaker` used by this
- :class:`.scoped_session`.
-
- See :meth:`.sessionmaker.configure`.
-
- """
-
- if self.registry.has():
- warn(
- "At least one scoped session is already present. "
- " configure() can not affect sessions that have "
- "already been created."
- )
-
- self.session_factory.configure(**kwargs)
+__all__ = ["scoped_session"]
@create_proxy_methods(
@@ -173,7 +116,7 @@ class ScopedSessionMixin:
"info",
],
)
-class scoped_session(ScopedSessionMixin):
+class scoped_session:
"""Provides scoped management of :class:`.Session` objects.
See :ref:`unitofwork_contextual` for a tutorial.
@@ -191,8 +134,9 @@ class scoped_session(ScopedSessionMixin):
session_factory: sessionmaker
"""The `session_factory` provided to `__init__` is stored in this
attribute and may be accessed at a later time. This can be useful when
- a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
- database is needed."""
+ a new non-scoped :class:`.Session` is needed."""
+
+ registry: ScopedRegistry[Session]
def __init__(
self,
@@ -222,6 +166,58 @@ class scoped_session(ScopedSessionMixin):
else:
self.registry = ThreadLocalRegistry(session_factory)
+ @property
+ def _proxied(self) -> Session:
+ return self.registry()
+
+ def __call__(self, **kw: Any) -> Session:
+ r"""Return the current :class:`.Session`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.Session` is not present. If the :class:`.Session` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ else:
+ sess = self.registry()
+ if not self._support_async and sess._is_asyncio:
+ warn_deprecated(
+ "Using `scoped_session` with asyncio is deprecated and "
+ "will raise an error in a future version. "
+ "Please use `async_scoped_session` instead.",
+ "1.4.23",
+ )
+ return sess
+
+ def configure(self, **kwargs: Any) -> None:
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
+
def remove(self) -> None:
"""Dispose of the current :class:`.Session`, if present.
@@ -494,9 +490,9 @@ class scoped_session(ScopedSessionMixin):
def connection(
self,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
execution_options: Optional[_ExecuteOptions] = None,
- ) -> "Connection":
+ ) -> Connection:
r"""Return a :class:`_engine.Connection` object corresponding to this
:class:`.Session` object's transactional state.
@@ -557,7 +553,7 @@ class scoped_session(ScopedSessionMixin):
statement: Executable,
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result:
@@ -1567,7 +1563,7 @@ class scoped_session(ScopedSessionMixin):
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
r"""Execute a statement and return a scalar result.
@@ -1597,7 +1593,7 @@ class scoped_session(ScopedSessionMixin):
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
r"""Execute a statement and return the results as scalars.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 55ce73cf5..a26c55a24 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -26,7 +26,6 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
-from typing import TypeVar
from typing import Union
import weakref
@@ -38,6 +37,7 @@ from . import loading
from . import persistence
from . import query
from . import state as statelib
+from ._typing import _O
from ._typing import is_composite_class
from ._typing import is_user_defined_option
from .base import _class_to_mapper
@@ -119,11 +119,12 @@ _sessions: weakref.WeakValueDictionary[
"""Weak-referencing dictionary of :class:`.Session` objects.
"""
-_O = TypeVar("_O", bound=object)
statelib._sessions = _sessions
_PKIdentityArgument = Union[Any, Tuple[Any, ...]]
+_BindArguments = Dict[str, Any]
+
_EntityBindKey = Union[Type[_O], "Mapper[_O]"]
_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
_SessionBind = Union["Engine", "Connection"]
@@ -251,7 +252,7 @@ class ORMExecuteState(util.MemoizedSlots):
parameters: Optional[_CoreAnyExecuteParams]
execution_options: _ExecuteOptions
local_execution_options: _ExecuteOptions
- bind_arguments: Dict[str, Any]
+ bind_arguments: _BindArguments
_compile_state_cls: Optional[Type[ORMCompileState]]
_starting_event_idx: int
_events_todo: List[Any]
@@ -263,7 +264,7 @@ class ORMExecuteState(util.MemoizedSlots):
statement: Executable,
parameters: Optional[_CoreAnyExecuteParams],
execution_options: _ExecuteOptions,
- bind_arguments: Dict[str, Any],
+ bind_arguments: _BindArguments,
compile_state_cls: Optional[Type[ORMCompileState]],
events_todo: List[_InstanceLevelDispatch[Session]],
):
@@ -286,7 +287,7 @@ class ORMExecuteState(util.MemoizedSlots):
statement: Optional[Executable] = None,
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
) -> Result:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
@@ -1626,9 +1627,9 @@ class Session(_SessionClassMethods, EventTarget):
def connection(
self,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
execution_options: Optional[_ExecuteOptions] = None,
- ) -> "Connection":
+ ) -> Connection:
r"""Return a :class:`_engine.Connection` object corresponding to this
:class:`.Session` object's transactional state.
@@ -1690,7 +1691,7 @@ class Session(_SessionClassMethods, EventTarget):
statement: Executable,
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
) -> Result:
@@ -1833,7 +1834,7 @@ class Session(_SessionClassMethods, EventTarget):
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> Any:
"""Execute a statement and return a scalar result.
@@ -1857,7 +1858,7 @@ class Session(_SessionClassMethods, EventTarget):
statement: Executable,
params: Optional[_CoreSingleExecuteParams] = None,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
- bind_arguments: Optional[Dict[str, Any]] = None,
+ bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> ScalarResult[Any]:
"""Execute a statement and return the results as scalars.
@@ -3099,7 +3100,7 @@ class Session(_SessionClassMethods, EventTarget):
_recursive: Dict[InstanceState[Any], object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> _O:
- mapper = _state_mapper(state)
+ mapper: Mapper[_O] = _state_mapper(state)
if state in _recursive:
return cast(_O, _recursive[state])
@@ -3249,6 +3250,7 @@ class Session(_SessionClassMethods, EventTarget):
if new_instance:
merged_state.manager.dispatch.load(merged_state, None)
+
return merged
def _validate_persistent(self, state: InstanceState[Any]) -> None:
@@ -4291,7 +4293,7 @@ class sessionmaker(_SessionClassMethods):
In Python, the ``__call__`` method is invoked on an object when
it is "called" in the same way as a function::
- Session = sessionmaker()
+ Session = sessionmaker(some_engine)
session = Session() # invokes sessionmaker.__call__()
"""
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 7ccda9565..58f141997 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -23,12 +23,12 @@ from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
-from typing import TypeVar
import weakref
from . import base
from . import exc as orm_exc
from . import interfaces
+from ._typing import _O
from ._typing import is_collection_impl
from .base import ATTR_WAS_SET
from .base import INIT_OK
@@ -62,8 +62,6 @@ if TYPE_CHECKING:
from ..ext.asyncio.session import async_session as _async_provider
from ..ext.asyncio.session import AsyncSession
-_T = TypeVar("_T", bound=Any)
-
if TYPE_CHECKING:
_sessions: weakref.WeakValueDictionary[int, Session]
else:
@@ -83,7 +81,7 @@ class _InstanceDictProto(Protocol):
@inspection._self_inspects
-class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
+class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]):
"""tracks state information at the instance level.
The :class:`.InstanceState` is a key object used by the
@@ -119,15 +117,15 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
"expired_attributes",
)
- manager: ClassManager[_T]
+ manager: ClassManager[_O]
session_id: Optional[int] = None
- key: Optional[_IdentityKeyType[_T]] = None
+ key: Optional[_IdentityKeyType[_O]] = None
runid: Optional[int] = None
load_options: Tuple[ORMOption, ...] = ()
load_path: PathRegistry = PathRegistry.root
insert_order: Optional[int] = None
_strong_obj: Optional[object] = None
- obj: weakref.ref[_T]
+ obj: weakref.ref[_O]
committed_state: Dict[str, Any]
@@ -159,7 +157,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
see also the ``unmodified`` collection which is intersected
against this set when a refresh operation occurs."""
- callables: Dict[str, Callable[[InstanceState[_T], PassiveFlag], Any]]
+ callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]]
"""A namespace where a per-state loader callable can be associated.
In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
@@ -174,7 +172,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
if not TYPE_CHECKING:
callables = util.EMPTY_DICT
- def __init__(self, obj: _T, manager: ClassManager[_T]):
+ def __init__(self, obj: _O, manager: ClassManager[_O]):
self.class_ = obj.__class__
self.manager = manager
self.obj = weakref.ref(obj, self._cleanup)
@@ -381,7 +379,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
return None
@property
- def object(self) -> Optional[_T]:
+ def object(self) -> Optional[_O]:
"""Return the mapped object represented by this
:class:`.InstanceState`.
@@ -411,7 +409,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
return self.key[1]
@property
- def identity_key(self) -> Optional[_IdentityKeyType[_T]]:
+ def identity_key(self) -> Optional[_IdentityKeyType[_O]]:
"""Return the identity key for the mapped object.
This is the key used to locate the object within
@@ -435,7 +433,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
return {}
@util.memoized_property
- def mapper(self) -> Mapper[_T]:
+ def mapper(self) -> Mapper[_O]:
"""Return the :class:`_orm.Mapper` used for this mapped object."""
return self.manager.mapper
@@ -452,7 +450,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
@classmethod
def _detach_states(
self,
- states: Iterable[InstanceState[_T]],
+ states: Iterable[InstanceState[_O]],
session: Session,
to_transient: bool = False,
) -> None:
@@ -497,7 +495,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
# used by the test suite, apparently
self._detach()
- def _cleanup(self, ref: weakref.ref[_T]) -> None:
+ def _cleanup(self, ref: weakref.ref[_O]) -> None:
"""Weakref callback cleanup.
This callable cleans out the state when it is being garbage
@@ -657,14 +655,14 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
@classmethod
def _instance_level_callable_processor(
- cls, manager: ClassManager[_T], fn: _LoaderCallable, key: Any
- ) -> Callable[[InstanceState[_T], _InstanceDict, Row], None]:
+ cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any
+ ) -> Callable[[InstanceState[_O], _InstanceDict, Row], None]:
impl = manager[key].impl
if is_collection_impl(impl):
fixed_impl = impl
def _set_callable(
- state: InstanceState[_T], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
@@ -676,7 +674,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
else:
def _set_callable(
- state: InstanceState[_T], dict_: _InstanceDict, row: Row
+ state: InstanceState[_O], dict_: _InstanceDict, row: Row
) -> None:
if "callables" not in state.__dict__:
state.callables = {}
@@ -768,7 +766,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_T]):
self.manager.dispatch.expire(self, attribute_names)
def _load_expired(
- self, state: InstanceState[_T], passive: PassiveFlag
+ self, state: InstanceState[_O], passive: PassiveFlag
) -> LoaderCallableStatus:
"""__call__ allows the InstanceState to act as a deferred
callable for loading expired attributes, which is also
diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py
index e961df1a3..1107c92b5 100644
--- a/lib/sqlalchemy/pool/events.py
+++ b/lib/sqlalchemy/pool/events.py
@@ -73,10 +73,8 @@ class PoolEvents(event.Events[Pool]):
return target.pool
elif isinstance(target, Pool):
return target
- elif hasattr(target, "dispatch") and hasattr(
- target.dispatch._events, "_no_async_engine_events"
- ):
- target.dispatch._events._no_async_engine_events()
+ elif hasattr(target, "_no_async_engine_events"):
+ target._no_async_engine_events()
else:
return None
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index ccd5e8c40..d58743340 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .selectable import _SelectIterable
from .selectable import FromClause
from ..engine import Connection
+ from ..engine import CursorResult
from ..engine import Result
from ..engine.base import _CompiledCacheType
from ..engine.interfaces import _CoreMultiExecuteParams
@@ -983,7 +984,7 @@ class Executable(roles.StatementRole, Generative):
distilled_params: _CoreMultiExecuteParams,
execution_options: _ExecuteOptionsParameter,
_force: bool = False,
- ) -> Result:
+ ) -> CursorResult:
...
@util.ro_non_memoized_property
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
index 28b062d3d..6ad099eef 100644
--- a/lib/sqlalchemy/util/_concurrency_py3k.py
+++ b/lib/sqlalchemy/util/_concurrency_py3k.py
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
import asyncio
from contextvars import copy_context as _copy_context
@@ -19,6 +20,8 @@ from .langhelpers import memoized_property
from .. import exc
from ..util.typing import Protocol
+_T = TypeVar("_T", bound=Any)
+
if typing.TYPE_CHECKING:
class greenlet(Protocol):
@@ -52,8 +55,6 @@ if not typing.TYPE_CHECKING:
except (ImportError, AttributeError):
_copy_context = None # noqa
-_T = TypeVar("_T", bound=Any)
-
def is_exit_exception(e: BaseException) -> bool:
# note asyncio.CancelledError is already BaseException
@@ -128,11 +129,11 @@ def await_fallback(awaitable: Awaitable[_T]) -> _T:
async def greenlet_spawn(
- fn: Callable[..., Any],
+ fn: Callable[..., _T],
*args: Any,
_require_await: bool = False,
**kwargs: Any,
-) -> Any:
+) -> _T:
"""Runs a sync function ``fn`` in a new greenlet.
The sync function can then use :func:`await_` to wait for async
@@ -143,6 +144,7 @@ async def greenlet_spawn(
:param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
"""
+ result: _T
context = _AsyncIoGreenlet(fn, getcurrent())
# runs the function synchronously in gl greenlet. If the execution
# is interrupted by await_, context is not dead and result is a