diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-04-07 12:37:23 -0400 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2022-04-12 02:09:50 +0000 |
| commit | aa9cd878e8249a4a758c7f968e929e92fede42a5 (patch) | |
| tree | 1be1c9dc24dd247a150be55d65bfc56ebaf111bc /lib/sqlalchemy/orm/session.py | |
| parent | 98eae4e181cb2d1bbc67ec834bfad29dcba7f461 (diff) | |
| download | sqlalchemy-aa9cd878e8249a4a758c7f968e929e92fede42a5.tar.gz | |
pep-484: session, instancestate, etc
Also adds some fixes to annotation-based mapping
that have come up, as well as starts to add more
pep-484 test cases
Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 936 |
1 files changed, 573 insertions, 363 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 77a97936b..55ce73cf5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -13,12 +13,20 @@ import itertools import sys import typing from typing import Any +from typing import Callable +from typing import cast from typing import Dict +from typing import Iterable +from typing import Iterator from typing import List +from typing import NoReturn from typing import Optional -from typing import overload +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -30,14 +38,20 @@ from . import loading from . import persistence from . import query from . import state as statelib +from ._typing import is_composite_class +from ._typing import is_user_defined_option from .base import _class_to_mapper -from .base import _IdentityKeyType from .base import _none_set from .base import _state_mapper from .base import instance_str +from .base import LoaderCallableStatus from .base import object_mapper from .base import object_state +from .base import PassiveFlag from .base import state_str +from .context import FromStatement +from .context import ORMCompileState +from .identity import IdentityMap from .query import Query from .state import InstanceState from .state_changes import _StateChange @@ -51,22 +65,41 @@ from .. import util from ..engine import Connection from ..engine import Engine from ..engine.util import TransactionalContext +from ..event import dispatcher +from ..event import EventTarget from ..inspection import inspect from ..sql import coercions from ..sql import dml from ..sql import roles +from ..sql import Select from ..sql import visitors from ..sql.base import CompileState +from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import IdentitySet from ..util.typing import Literal +from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from .interfaces import ORMOption + from .interfaces import UserDefinedOption from .mapper import Mapper + from .path_registry import PathRegistry + from ..engine import Result from ..engine import Row + from ..engine.base import Transaction + from ..engine.base import TwoPhaseTransaction + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ExecuteOptionsParameter + from ..engine.result import ScalarResult + from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument - from ..sql._typing import _ExecuteOptions - from ..sql._typing import _ExecuteParams from ..sql.base import Executable + from ..sql.elements import ClauseElement from ..sql.schema import Table __all__ = [ @@ -80,14 +113,45 @@ __all__ = [ "object_session", ] -_sessions = weakref.WeakValueDictionary() +_sessions: weakref.WeakValueDictionary[ + int, Session +] = weakref.WeakValueDictionary() """Weak-referencing dictionary of :class:`.Session` objects. """ +_O = TypeVar("_O", bound=object) statelib._sessions = _sessions +_PKIdentityArgument = Union[Any, Tuple[Any, ...]] -def _state_session(state): +_EntityBindKey = Union[Type[_O], "Mapper[_O]"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"] +_SessionBind = Union["Engine", "Connection"] + + +class _ConnectionCallableProto(Protocol): + """a callable that returns a :class:`.Connection` given an instance. + + This callable, when present on a :class:`.Session`, is called only from the + ORM's persistence mechanism (i.e. the unit of work flush process) to allow + for connection-per-instance schemes (i.e. horizontal sharding) to be used + as persistence time. + + This callable is not present on a plain :class:`.Session`, however + is established when using the horizontal sharding extension. + + """ + + def __call__( + self, + mapper: Optional[Mapper[Any]] = None, + instance: Optional[object] = None, + **kw: Any, + ) -> Connection: + ... + + +def _state_session(state: InstanceState[Any]) -> Optional[Session]: """Given an :class:`.InstanceState`, return the :class:`.Session` associated, if any. """ @@ -110,39 +174,16 @@ class _SessionClassMethods: close_all_sessions() @classmethod - @overload - def identity_key( - cls, - class_: type, - ident: Tuple[Any, ...], - *, - identity_token: Optional[str], - ) -> _IdentityKeyType: - ... - - @classmethod - @overload - def identity_key(cls, *, instance: Any) -> _IdentityKeyType: - ... - - @classmethod - @overload - def identity_key( - cls, class_: type, *, row: "Row", identity_token: Optional[str] - ) -> _IdentityKeyType: - ... - - @classmethod @util.preload_module("sqlalchemy.orm.util") def identity_key( cls, - class_=None, - ident=None, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, *, - instance=None, - row=None, - identity_token=None, - ) -> _IdentityKeyType: + instance: Optional[Any] = None, + row: Optional[Row] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: """Return an identity key. This is an alias of :func:`.util.identity_key`. @@ -157,7 +198,7 @@ class _SessionClassMethods: ) @classmethod - def object_session(cls, instance: Any) -> "Session": + def object_session(cls, instance: object) -> Optional[Session]: """Return the :class:`.Session` to which an object belongs. This is an alias of :func:`.object_session`. @@ -205,26 +246,26 @@ class ORMExecuteState(util.MemoizedSlots): "_update_execution_options", ) - session: "Session" - statement: "Executable" - parameters: "_ExecuteParams" - execution_options: "_ExecuteOptions" - local_execution_options: "_ExecuteOptions" + session: Session + statement: Executable + parameters: Optional[_CoreAnyExecuteParams] + execution_options: _ExecuteOptions + local_execution_options: _ExecuteOptions bind_arguments: Dict[str, Any] - _compile_state_cls: Type[context.ORMCompileState] - _starting_event_idx: Optional[int] + _compile_state_cls: Optional[Type[ORMCompileState]] + _starting_event_idx: int _events_todo: List[Any] - _update_execution_options: Optional["_ExecuteOptions"] + _update_execution_options: Optional[_ExecuteOptions] def __init__( self, - session: "Session", - statement: "Executable", - parameters: "_ExecuteParams", - execution_options: "_ExecuteOptions", + session: Session, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams], + execution_options: _ExecuteOptions, bind_arguments: Dict[str, Any], - compile_state_cls: Type[context.ORMCompileState], - events_todo: List[Any], + compile_state_cls: Optional[Type[ORMCompileState]], + events_todo: List[_InstanceLevelDispatch[Session]], ): self.session = session self.statement = statement @@ -237,16 +278,16 @@ class ORMExecuteState(util.MemoizedSlots): self._compile_state_cls = compile_state_cls self._events_todo = list(events_todo) - def _remaining_events(self): + def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]: return self._events_todo[self._starting_event_idx + 1 :] def invoke_statement( self, - statement=None, - params=None, - execution_options=None, - bind_arguments=None, - ): + statement: Optional[Executable] = None, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + bind_arguments: Optional[Dict[str, Any]] = None, + ) -> Result: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -270,9 +311,12 @@ class ORMExecuteState(util.MemoizedSlots): :param statement: optional statement to be invoked, in place of the statement currently represented by :attr:`.ORMExecuteState.statement`. - :param params: optional dictionary of parameters which will be merged - into the existing :attr:`.ORMExecuteState.parameters` of this - :class:`.ORMExecuteState`. + :param params: optional dictionary of parameters or list of parameters + which will be merged into the existing + :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`. + + .. versionchanged:: 2.0 a list of parameter dictionaries is accepted + for executemany executions. :param execution_options: optional dictionary of execution options will be merged into the existing @@ -302,9 +346,32 @@ class ORMExecuteState(util.MemoizedSlots): _bind_arguments.update(bind_arguments) _bind_arguments["_sa_skip_events"] = True + _params: Optional[_CoreAnyExecuteParams] if params: - _params = dict(self.parameters) - _params.update(params) + if self.is_executemany: + _params = [] + exec_many_parameters = cast( + "List[Dict[str, Any]]", self.parameters + ) + for _existing_params, _new_params in itertools.zip_longest( + exec_many_parameters, + cast("List[Dict[str, Any]]", params), + ): + if _existing_params is None or _new_params is None: + raise sa_exc.InvalidRequestError( + f"Can't apply executemany parameters to " + f"statement; number of parameter sets passed to " + f"Session.execute() ({len(exec_many_parameters)}) " + f"does not match number of parameter sets given " + f"to ORMExecuteState.invoke_statement() " + f"({len(params)})" + ) + _existing_params = dict(_existing_params) + _existing_params.update(_new_params) + _params.append(_existing_params) + else: + _params = dict(cast("Dict[str, Any]", self.parameters)) + _params.update(cast("Dict[str, Any]", params)) else: _params = self.parameters @@ -321,7 +388,7 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def bind_mapper(self): + def bind_mapper(self) -> Optional[Mapper[Any]]: """Return the :class:`_orm.Mapper` that is the primary "bind" mapper. For an :class:`_orm.ORMExecuteState` object invoking an ORM @@ -349,7 +416,7 @@ class ORMExecuteState(util.MemoizedSlots): return self.bind_arguments.get("mapper", None) @property - def all_mappers(self): + def all_mappers(self) -> Sequence[Mapper[Any]]: """Return a sequence of all :class:`_orm.Mapper` objects that are involved at the top level of this statement. @@ -369,7 +436,7 @@ class ORMExecuteState(util.MemoizedSlots): """ if not self.is_orm_statement: return [] - elif self.is_select: + elif isinstance(self.statement, (Select, FromStatement)): result = [] seen = set() for d in self.statement.column_descriptions: @@ -380,13 +447,13 @@ class ORMExecuteState(util.MemoizedSlots): seen.add(insp.mapper) result.append(insp.mapper) return result - elif self.is_update or self.is_delete: + elif self.statement.is_dml and self.bind_mapper: return [self.bind_mapper] else: return [] @property - def is_orm_statement(self): + def is_orm_statement(self) -> bool: """return True if the operation is an ORM statement. This indicates that the select(), update(), or delete() being @@ -399,44 +466,64 @@ class ORMExecuteState(util.MemoizedSlots): return self._compile_state_cls is not None @property - def is_select(self): + def is_executemany(self) -> bool: + """return True if the parameters are a multi-element list of + dictionaries with more than one dictionary. + + .. versionadded:: 2.0 + + """ + return isinstance(self.parameters, list) + + @property + def is_select(self) -> bool: """return True if this is a SELECT operation.""" return self.statement.is_select @property - def is_insert(self): + def is_insert(self) -> bool: """return True if this is an INSERT operation.""" return self.statement.is_dml and self.statement.is_insert @property - def is_update(self): + def is_update(self) -> bool: """return True if this is an UPDATE operation.""" return self.statement.is_dml and self.statement.is_update @property - def is_delete(self): + def is_delete(self) -> bool: """return True if this is a DELETE operation.""" return self.statement.is_dml and self.statement.is_delete @property - def _is_crud(self): + def _is_crud(self) -> bool: return isinstance(self.statement, (dml.Update, dml.Delete)) - def update_execution_options(self, **opts): + def update_execution_options(self, **opts: _ExecuteOptions) -> None: + """Update the local execution options with new values.""" # TODO: no coverage self.local_execution_options = self.local_execution_options.union(opts) - def _orm_compile_options(self): + def _orm_compile_options( + self, + ) -> Optional[ + Union[ + context.ORMCompileState.default_compile_options, + Type[context.ORMCompileState.default_compile_options], + ] + ]: if not self.is_select: return None opts = self.statement._compile_options - if opts.isinstance(context.ORMCompileState.default_compile_options): - return opts + if opts is not None and opts.isinstance( + context.ORMCompileState.default_compile_options + ): + return opts # type: ignore else: return None @property - def lazy_loaded_from(self): + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: """An :class:`.InstanceState` that is using this statement execution for a lazy load operation. @@ -451,7 +538,7 @@ class ORMExecuteState(util.MemoizedSlots): return self.load_options._lazy_loaded_from @property - def loader_strategy_path(self): + def loader_strategy_path(self) -> Optional[PathRegistry]: """Return the :class:`.PathRegistry` for the current load path. This object represents the "path" in a query along relationships @@ -465,7 +552,7 @@ class ORMExecuteState(util.MemoizedSlots): return None @property - def is_column_load(self): + def is_column_load(self) -> bool: """Return True if the operation is refreshing column-oriented attributes on an existing ORM object. @@ -492,7 +579,7 @@ class ORMExecuteState(util.MemoizedSlots): return opts is not None and opts._for_refresh_state @property - def is_relationship_load(self): + def is_relationship_load(self) -> bool: """Return True if this load is loading objects on behalf of a relationship. @@ -518,7 +605,12 @@ class ORMExecuteState(util.MemoizedSlots): return path is not None and not path.is_root @property - def load_options(self): + def load_options( + self, + ) -> Union[ + context.QueryContext.default_load_options, + Type[context.QueryContext.default_load_options], + ]: """Return the load_options that will be used for this execution.""" if not self.is_select: @@ -531,7 +623,12 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def update_delete_options(self): + def update_delete_options( + self, + ) -> Union[ + persistence.BulkUDCompileState.default_update_options, + Type[persistence.BulkUDCompileState.default_update_options], + ]: """Return the update_delete_options that will be used for this execution.""" @@ -546,7 +643,7 @@ class ORMExecuteState(util.MemoizedSlots): ) @property - def user_defined_options(self): + def user_defined_options(self) -> Sequence[UserDefinedOption]: """The sequence of :class:`.UserDefinedOptions` that have been associated with the statement being invoked. @@ -554,7 +651,7 @@ class ORMExecuteState(util.MemoizedSlots): return [ opt for opt in self.statement._with_options - if not opt._is_compile_state and not opt._is_legacy_option + if is_user_defined_option(opt) ] @@ -597,14 +694,29 @@ class SessionTransaction(_StateChange, TransactionalContext): """ - _rollback_exception = None + _rollback_exception: Optional[BaseException] = None + + _connections: Dict[ + Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool] + ] + session: Session + _parent: Optional[SessionTransaction] + + _state: SessionTransactionState + + _new: weakref.WeakKeyDictionary[InstanceState[Any], object] + _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object] + _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object] + _key_switches: weakref.WeakKeyDictionary[ + InstanceState[Any], Tuple[Any, Any] + ] def __init__( self, - session, - parent=None, - nested=False, - autobegin=False, + session: Session, + parent: Optional[SessionTransaction] = None, + nested: bool = False, + autobegin: bool = False, ): TransactionalContext._trans_ctx_check(session) @@ -629,7 +741,9 @@ class SessionTransaction(_StateChange, TransactionalContext): self.session.dispatch.after_transaction_create(self.session, self) - def _raise_for_prerequisite_state(self, operation_name, state): + def _raise_for_prerequisite_state( + self, operation_name: str, state: SessionTransactionState + ) -> NoReturn: if state is SessionTransactionState.DEACTIVE: if self._rollback_exception: raise sa_exc.PendingRollbackError( @@ -655,7 +769,7 @@ class SessionTransaction(_StateChange, TransactionalContext): ) @property - def parent(self): + def parent(self) -> Optional[SessionTransaction]: """The parent :class:`.SessionTransaction` of this :class:`.SessionTransaction`. @@ -673,7 +787,7 @@ class SessionTransaction(_StateChange, TransactionalContext): """ return self._parent - nested = False + nested: bool = False """Indicates if this is a nested, or SAVEPOINT, transaction. When :attr:`.SessionTransaction.nested` is True, it is expected @@ -682,33 +796,40 @@ class SessionTransaction(_StateChange, TransactionalContext): """ @property - def is_active(self): + def is_active(self) -> bool: return ( self.session is not None and self._state is SessionTransactionState.ACTIVE ) @property - def _is_transaction_boundary(self): + def _is_transaction_boundary(self) -> bool: return self.nested or not self._parent @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def connection(self, bindkey, execution_options=None, **kwargs): + def connection( + self, + bindkey: Optional[Mapper[Any]], + execution_options: Optional[_ExecuteOptions] = None, + **kwargs: Any, + ) -> Connection: bind = self.session.get_bind(bindkey, **kwargs) return self._connection_for_bind(bind, execution_options) @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def _begin(self, nested=False): + def _begin(self, nested: bool = False) -> SessionTransaction: return SessionTransaction(self.session, self, nested=nested) - def _iterate_self_and_parents(self, upto=None): + def _iterate_self_and_parents( + self, upto: Optional[SessionTransaction] = None + ) -> Iterable[SessionTransaction]: current = self - result = () + result: Tuple[SessionTransaction, ...] = () while current: result += (current,) if current._parent is upto: @@ -723,12 +844,14 @@ class SessionTransaction(_StateChange, TransactionalContext): return result - def _take_snapshot(self, autobegin=False): + def _take_snapshot(self, autobegin: bool = False) -> None: if not self._is_transaction_boundary: - self._new = self._parent._new - self._deleted = self._parent._deleted - self._dirty = self._parent._dirty - self._key_switches = self._parent._key_switches + parent = self._parent + assert parent is not None + self._new = parent._new + self._deleted = parent._deleted + self._dirty = parent._dirty + self._key_switches = parent._key_switches return if not autobegin and not self.session._flushing: @@ -739,7 +862,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self._dirty = weakref.WeakKeyDictionary() self._key_switches = weakref.WeakKeyDictionary() - def _restore_snapshot(self, dirty_only=False): + def _restore_snapshot(self, dirty_only: bool = False) -> None: """Restore the restoration state taken before a transaction began. Corresponds to a rollback. @@ -771,7 +894,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if not dirty_only or s.modified or s in self._dirty: s._expire(s.dict, self.session.identity_map._modified) - def _remove_snapshot(self): + def _remove_snapshot(self) -> None: """Remove the restoration state taken before a transaction began. Corresponds to a commit. @@ -788,15 +911,21 @@ class SessionTransaction(_StateChange, TransactionalContext): ) self._deleted.clear() elif self.nested: - self._parent._new.update(self._new) - self._parent._dirty.update(self._dirty) - self._parent._deleted.update(self._deleted) - self._parent._key_switches.update(self._key_switches) + parent = self._parent + assert parent is not None + parent._new.update(self._new) + parent._dirty.update(self._dirty) + parent._deleted.update(self._deleted) + parent._key_switches.update(self._key_switches) @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE ) - def _connection_for_bind(self, bind, execution_options): + def _connection_for_bind( + self, + bind: _SessionBind, + execution_options: Optional[_ExecuteOptions], + ) -> Connection: if bind in self._connections: if execution_options: @@ -829,6 +958,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if execution_options: conn = conn.execution_options(**execution_options) + transaction: Transaction if self.session.twophase and self._parent is None: transaction = conn.begin_twophase() elif self.nested: @@ -837,9 +967,9 @@ class SessionTransaction(_StateChange, TransactionalContext): # if given a future connection already in a transaction, don't # commit that transaction unless it is a savepoint if conn.in_nested_transaction(): - transaction = conn.get_nested_transaction() + transaction = conn._get_required_nested_transaction() else: - transaction = conn.get_transaction() + transaction = conn._get_required_transaction() should_commit = False else: transaction = conn.begin() @@ -861,7 +991,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self.session.dispatch.after_begin(self.session, self, conn) return conn - def prepare(self): + def prepare(self) -> None: if self._parent is not None or not self.session.twophase: raise sa_exc.InvalidRequestError( "'twophase' mode not enabled, or not root transaction; " @@ -872,12 +1002,13 @@ class SessionTransaction(_StateChange, TransactionalContext): @_StateChange.declare_states( (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED ) - def _prepare_impl(self): + def _prepare_impl(self) -> None: if self._parent is None or self.nested: self.session.dispatch.before_commit(self.session) stx = self.session._transaction + assert stx is not None if stx is not self: for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.commit() @@ -897,7 +1028,7 @@ class SessionTransaction(_StateChange, TransactionalContext): if self._parent is None and self.session.twophase: try: for t in set(self._connections.values()): - t[1].prepare() + cast("TwoPhaseTransaction", t[1]).prepare() except: with util.safe_reraise(): self.rollback() @@ -929,9 +1060,7 @@ class SessionTransaction(_StateChange, TransactionalContext): self.close() if _to_root and self._parent: - return self._parent.commit(_to_root=True) - - return self._parent + self._parent.commit(_to_root=True) @_StateChange.declare_states( ( @@ -941,9 +1070,12 @@ class SessionTransaction(_StateChange, TransactionalContext): ), SessionTransactionState.CLOSED, ) - def rollback(self, _capture_exception=False, _to_root=False): + def rollback( + self, _capture_exception: bool = False, _to_root: bool = False + ) -> None: stx = self.session._transaction + assert stx is not None if stx is not self: for subtransaction in stx._iterate_self_and_parents(upto=self): subtransaction.close() @@ -993,19 +1125,18 @@ class SessionTransaction(_StateChange, TransactionalContext): if self._parent and _capture_exception: self._parent._rollback_exception = sys.exc_info()[1] - if rollback_err: + if rollback_err and rollback_err[1]: raise rollback_err[1].with_traceback(rollback_err[2]) sess.dispatch.after_soft_rollback(sess, self) if _to_root and self._parent: - return self._parent.rollback(_to_root=True) - return self._parent + self._parent.rollback(_to_root=True) @_StateChange.declare_states( _StateChangeStates.ANY, SessionTransactionState.CLOSED ) - def close(self, invalidate=False): + def close(self, invalidate: bool = False) -> None: if self.nested: self.session._nested_transaction = ( self._previous_nested_transaction @@ -1027,25 +1158,30 @@ class SessionTransaction(_StateChange, TransactionalContext): self._state = SessionTransactionState.CLOSED sess = self.session - self.session = None - self._connections = None + # TODO: these two None sets were historically after the + # event hook below, and in 2.0 I changed it this way for some reason, + # and I remember there being a reason, but not what it was. + # Why do we need to get rid of them at all? test_memusage::CycleTest + # passes with these commented out. + # self.session = None # type: ignore + # self._connections = None # type: ignore sess.dispatch.after_transaction_end(sess, self) - def _get_subject(self): + def _get_subject(self) -> Session: return self.session - def _transaction_is_active(self): + def _transaction_is_active(self) -> bool: return self._state is SessionTransactionState.ACTIVE - def _transaction_is_closed(self): + def _transaction_is_closed(self) -> bool: return self._state is SessionTransactionState.CLOSED - def _rollback_can_be_called(self): + def _rollback_can_be_called(self) -> bool: return self._state not in (COMMITTED, CLOSED) -class Session(_SessionClassMethods): +class Session(_SessionClassMethods, EventTarget): """Manages persistence operations for ORM-mapped objects. The Session's usage paradigm is described at :doc:`/orm/session`. @@ -1055,15 +1191,27 @@ class Session(_SessionClassMethods): _is_asyncio = False - identity_map: identity.IdentityMap - _new: Dict["InstanceState", Any] - _deleted: Dict["InstanceState", Any] + dispatch: dispatcher[Session] + + identity_map: IdentityMap + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + .. seealso:: + + :func:`.identity_key` - helper function to produce the keys used + in this dictionary. + + """ + + _new: Dict[InstanceState[Any], Any] + _deleted: Dict[InstanceState[Any], Any] bind: Optional[Union[Engine, Connection]] - __binds: Dict[ - Union[type, "Mapper", "Table"], - Union[engine.Engine, engine.Connection], - ] - _flusing: bool + __binds: Dict[_SessionBindKey, _SessionBind] + _flushing: bool _warn_on_events: bool _transaction: Optional[SessionTransaction] _nested_transaction: Optional[SessionTransaction] @@ -1072,24 +1220,19 @@ class Session(_SessionClassMethods): expire_on_commit: bool enable_baked_queries: bool twophase: bool - _query_cls: Type[Query] + _query_cls: Type[Query[Any]] def __init__( self, - bind: Optional[Union[engine.Engine, engine.Connection]] = None, + bind: Optional[_SessionBind] = None, autoflush: bool = True, future: Literal[True] = True, expire_on_commit: bool = True, twophase: bool = False, - binds: Optional[ - Dict[ - Union[type, "Mapper", "Table"], - Union[engine.Engine, engine.Connection], - ] - ] = None, + binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None, enable_baked_queries: bool = True, info: Optional[Dict[Any, Any]] = None, - query_cls: Optional[Type[query.Query]] = None, + query_cls: Optional[Type[Query[Any]]] = None, autocommit: Literal[False] = False, ): r"""Construct a new Session. @@ -1249,23 +1392,23 @@ class Session(_SessionClassMethods): _sessions[self.hash_key] = self # used by sqlalchemy.engine.util.TransactionalContext - _trans_context_manager = None + _trans_context_manager: Optional[TransactionalContext] = None - connection_callable = None + connection_callable: Optional[_ConnectionCallableProto] = None - def __enter__(self): + def __enter__(self) -> Session: return self - def __exit__(self, type_, value, traceback): + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() @contextlib.contextmanager - def _maker_context_manager(self): + def _maker_context_manager(self) -> Iterator[Session]: with self: with self.begin(): yield self - def in_transaction(self): + def in_transaction(self) -> bool: """Return True if this :class:`_orm.Session` has begun a transaction. .. versionadded:: 1.4 @@ -1278,7 +1421,7 @@ class Session(_SessionClassMethods): """ return self._transaction is not None - def in_nested_transaction(self): + def in_nested_transaction(self) -> bool: """Return True if this :class:`_orm.Session` has begun a nested transaction, e.g. SAVEPOINT. @@ -1287,7 +1430,7 @@ class Session(_SessionClassMethods): """ return self._nested_transaction is not None - def get_transaction(self): + def get_transaction(self) -> Optional[SessionTransaction]: """Return the current root transaction in progress, if any. .. versionadded:: 1.4 @@ -1298,7 +1441,7 @@ class Session(_SessionClassMethods): trans = trans._parent return trans - def get_nested_transaction(self): + def get_nested_transaction(self) -> Optional[SessionTransaction]: """Return the current nested transaction in progress, if any. .. versionadded:: 1.4 @@ -1308,7 +1451,7 @@ class Session(_SessionClassMethods): return self._nested_transaction @util.memoized_property - def info(self): + def info(self) -> Dict[Any, Any]: """A user-modifiable dictionary. The initial value of this dictionary can be populated using the @@ -1320,16 +1463,18 @@ class Session(_SessionClassMethods): """ return {} - def _autobegin(self): + def _autobegin_t(self) -> SessionTransaction: if self._transaction is None: trans = SessionTransaction(self, autobegin=True) assert self._transaction is trans - return True + return trans - return False + return self._transaction - def begin(self, nested=False, _subtrans=False): + def begin( + self, nested: bool = False, _subtrans: bool = False + ) -> SessionTransaction: """Begin a transaction, or nested transaction, on this :class:`.Session`, if one is not already begun. @@ -1364,13 +1509,16 @@ class Session(_SessionClassMethods): """ - if self._autobegin(): + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + if not nested and not _subtrans: - return self._transaction + return trans - if self._transaction is not None: + if trans is not None: if _subtrans or nested: - trans = self._transaction._begin(nested=nested) + trans = trans._begin(nested=nested) assert self._transaction is trans if nested: self._nested_transaction = trans @@ -1386,9 +1534,12 @@ class Session(_SessionClassMethods): trans = SessionTransaction(self) assert self._transaction is trans - return self._transaction # needed for __enter__/__exit__ hook + if TYPE_CHECKING: + assert self._transaction is not None + + return trans # needed for __enter__/__exit__ hook - def begin_nested(self): + def begin_nested(self) -> SessionTransaction: """Begin a "nested" transaction on this Session, e.g. SAVEPOINT. The target database(s) and associated drivers must support SQL @@ -1413,7 +1564,7 @@ class Session(_SessionClassMethods): """ return self.begin(nested=True) - def rollback(self): + def rollback(self) -> None: """Rollback the current transaction in progress. If no transaction is in progress, this method is a pass-through. @@ -1450,11 +1601,11 @@ class Session(_SessionClassMethods): :ref:`unitofwork_transaction` """ - if self._transaction is None: - if not self._autobegin(): - raise sa_exc.InvalidRequestError("No transaction is begun.") + trans = self._transaction + if trans is None: + trans = self._autobegin_t() - self._transaction.commit(_to_root=True) + trans.commit(_to_root=True) def prepare(self) -> None: """Prepare the current transaction in progress for two phase commit. @@ -1467,16 +1618,16 @@ class Session(_SessionClassMethods): :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. """ - if self._transaction is None: - if not self._autobegin(): - raise sa_exc.InvalidRequestError("No transaction is begun.") + trans = self._transaction + if trans is None: + trans = self._autobegin_t() - self._transaction.prepare() + trans.prepare() def connection( self, bind_arguments: Optional[Dict[str, Any]] = None, - execution_options: Optional["_ExecuteOptions"] = None, + execution_options: Optional[_ExecuteOptions] = None, ) -> "Connection": r"""Return a :class:`_engine.Connection` object corresponding to this :class:`.Session` object's transactional state. @@ -1521,24 +1672,28 @@ class Session(_SessionClassMethods): execution_options=execution_options, ) - def _connection_for_bind(self, engine, execution_options=None, **kw): + def _connection_for_bind( + self, + engine: _SessionBind, + execution_options: Optional[_ExecuteOptions] = None, + **kw: Any, + ) -> Connection: TransactionalContext._trans_ctx_check(self) - if self._transaction is None: - assert self._autobegin() - return self._transaction._connection_for_bind( - engine, execution_options - ) + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + return trans._connection_for_bind(engine, execution_options) def execute( self, - statement: "Executable", - params: Optional["_ExecuteParams"] = None, - execution_options: "_ExecuteOptions" = util.EMPTY_DICT, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ): + ) -> Result: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing @@ -1603,6 +1758,8 @@ class Session(_SessionClassMethods): compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) + if TYPE_CHECKING: + assert isinstance(compile_state_cls, ORMCompileState) else: compile_state_cls = None @@ -1645,9 +1802,9 @@ class Session(_SessionClassMethods): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - result = fn(orm_exec_state) - if result: - return result + fn_result: Optional[Result] = fn(orm_exec_state) + if fn_result: + return fn_result statement = orm_exec_state.statement execution_options = orm_exec_state.local_execution_options @@ -1655,7 +1812,9 @@ class Session(_SessionClassMethods): bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind) - result = conn.execute(statement, params or {}, execution_options) + result: Result = conn.execute( + statement, params or {}, execution_options + ) if compile_state_cls: result = compile_state_cls.orm_setup_cursor_result( @@ -1671,12 +1830,12 @@ class Session(_SessionClassMethods): def scalar( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> Any: """Execute a statement and return a scalar result. Usage and parameters are the same as that of @@ -1695,12 +1854,12 @@ class Session(_SessionClassMethods): def scalars( self, - statement, - params=None, - execution_options=util.EMPTY_DICT, - bind_arguments=None, - **kw, - ): + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[Any]: """Execute a statement and return the results as scalars. Usage and parameters are the same as that of @@ -1722,7 +1881,7 @@ class Session(_SessionClassMethods): **kw, ).scalars() - def close(self): + def close(self) -> None: """Close out the transactional resources and ORM objects used by this :class:`_orm.Session`. @@ -1754,7 +1913,7 @@ class Session(_SessionClassMethods): """ self._close_impl(invalidate=False) - def invalidate(self): + def invalidate(self) -> None: """Close this Session, using connection invalidation. This is a variant of :meth:`.Session.close` that will additionally @@ -1790,13 +1949,13 @@ class Session(_SessionClassMethods): """ self._close_impl(invalidate=True) - def _close_impl(self, invalidate): + def _close_impl(self, invalidate: bool) -> None: self.expunge_all() if self._transaction is not None: for transaction in self._transaction._iterate_self_and_parents(): transaction.close(invalidate) - def expunge_all(self): + def expunge_all(self) -> None: """Remove all object instances from this ``Session``. This is equivalent to calling ``expunge(obj)`` on all objects in this @@ -1812,7 +1971,7 @@ class Session(_SessionClassMethods): statelib.InstanceState._detach_states(all_states, self) - def _add_bind(self, key, bind): + def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None: try: insp = inspect(key) except sa_exc.NoInspectionAvailable as err: @@ -1834,7 +1993,9 @@ class Session(_SessionClassMethods): "Not an acceptable bind target: %s" % key ) - def bind_mapper(self, mapper, bind): + def bind_mapper( + self, mapper: _EntityBindKey[_O], bind: _SessionBind + ) -> None: """Associate a :class:`_orm.Mapper` or arbitrary Python class with a "bind", e.g. an :class:`_engine.Engine` or :class:`_engine.Connection`. @@ -1862,7 +2023,7 @@ class Session(_SessionClassMethods): """ self._add_bind(mapper, bind) - def bind_table(self, table, bind): + def bind_table(self, table: Table, bind: _SessionBind) -> None: """Associate a :class:`_schema.Table` with a "bind", e.g. an :class:`_engine.Engine` or :class:`_engine.Connection`. @@ -1892,12 +2053,12 @@ class Session(_SessionClassMethods): def get_bind( self, - mapper=None, - clause=None, - bind=None, - _sa_skip_events=None, - _sa_skip_for_implicit_returning=False, - ): + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + ) -> Union[Engine, Connection]: """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, @@ -1995,23 +2156,25 @@ class Session(_SessionClassMethods): # look more closely at the mapper. if mapper is not None: try: - mapper = inspect(mapper) + inspected_mapper = inspect(mapper) except sa_exc.NoInspectionAvailable as err: if isinstance(mapper, type): raise exc.UnmappedClassError(mapper) from err else: raise + else: + inspected_mapper = None # match up the mapper or clause in the __binds if self.__binds: # matching mappers and selectables to entries in the # binds dictionary; supported use case. - if mapper: - for cls in mapper.class_.__mro__: + if inspected_mapper: + for cls in inspected_mapper.class_.__mro__: if cls in self.__binds: return self.__binds[cls] if clause is None: - clause = mapper.persist_selectable + clause = inspected_mapper.persist_selectable if clause is not None: plugin_subject = clause._propagate_attrs.get( @@ -2025,6 +2188,8 @@ class Session(_SessionClassMethods): for obj in visitors.iterate(clause): if obj in self.__binds: + if TYPE_CHECKING: + assert isinstance(obj, Table) return self.__binds[obj] # none of the __binds matched, but we have a fallback bind. @@ -2033,17 +2198,19 @@ class Session(_SessionClassMethods): return self.bind context = [] - if mapper is not None: - context.append("mapper %s" % mapper) + if inspected_mapper is not None: + context.append(f"mapper {inspected_mapper}") if clause is not None: context.append("SQL expression") raise sa_exc.UnboundExecutionError( - "Could not locate a bind configured on %s or this Session." - % (", ".join(context),), + f"Could not locate a bind configured on " + f'{", ".join(context)} or this Session.' ) - def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query: + def query( + self, *entities: _ColumnsClauseArgument, **kwargs: Any + ) -> Query[Any]: """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2065,12 +2232,12 @@ class Session(_SessionClassMethods): def _identity_lookup( self, - mapper, - primary_key_identity, - identity_token=None, - passive=attributes.PASSIVE_OFF, - lazy_loaded_from=None, - ): + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Any = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + ) -> Union[Optional[_O], LoaderCallableStatus]: """Locate an object in the identity map. Given a primary key identity, constructs an identity key and then @@ -2117,9 +2284,9 @@ class Session(_SessionClassMethods): ) return loading.get_from_identity(self, mapper, key, passive) - @property + @util.non_memoized_property @contextlib.contextmanager - def no_autoflush(self): + def no_autoflush(self) -> Iterator[Session]: """Return a context manager that disables autoflush. e.g.:: @@ -2145,7 +2312,7 @@ class Session(_SessionClassMethods): finally: self.autoflush = autoflush - def _autoflush(self): + def _autoflush(self) -> None: if self.autoflush and not self._flushing: try: self.flush() @@ -2161,7 +2328,12 @@ class Session(_SessionClassMethods): ) raise e.with_traceback(sys.exc_info()[2]) - def refresh(self, instance, attribute_names=None, with_for_update=None): + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: Optional[ForUpdateArg] = None, + ) -> None: """Expire and refresh attributes on the given instance. The selected attributes will first be expired as they would when using @@ -2233,7 +2405,7 @@ class Session(_SessionClassMethods): "A blank dictionary is ambiguous." ) - with_for_update = query.ForUpdateArg._from_argument(with_for_update) + with_for_update = ForUpdateArg._from_argument(with_for_update) stmt = sql.select(object_mapper(instance)) if ( @@ -2251,7 +2423,7 @@ class Session(_SessionClassMethods): "Could not refresh instance '%s'" % instance_str(instance) ) - def expire_all(self): + def expire_all(self) -> None: """Expires all persistent instances within this Session. When any attributes on a persistent instance is next accessed, @@ -2286,7 +2458,9 @@ class Session(_SessionClassMethods): for state in self.identity_map.all_states(): state._expire(state.dict, self.identity_map._modified) - def expire(self, instance, attribute_names=None): + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: """Expire the attributes on an instance. Marks the attributes of an instance as out of date. When an expired @@ -2329,7 +2503,11 @@ class Session(_SessionClassMethods): raise exc.UnmappedInstanceError(instance) from err self._expire_state(state, attribute_names) - def _expire_state(self, state, attribute_names): + def _expire_state( + self, + state: InstanceState[Any], + attribute_names: Optional[Iterable[str]], + ) -> None: self._validate_persistent(state) if attribute_names: state._expire_attributes(state.dict, attribute_names) @@ -2343,7 +2521,9 @@ class Session(_SessionClassMethods): for o, m, st_, dct_ in cascaded: self._conditional_expire(st_) - def _conditional_expire(self, state, autoflush=None): + def _conditional_expire( + self, state: InstanceState[Any], autoflush: Optional[bool] = None + ) -> None: """Expire a state if persistent, else expunge if pending""" if state.key: @@ -2352,7 +2532,7 @@ class Session(_SessionClassMethods): self._new.pop(state) state._detach(self) - def expunge(self, instance): + def expunge(self, instance: object) -> None: """Remove the `instance` from this ``Session``. This will free all internal references to the instance. Cascading @@ -2373,7 +2553,9 @@ class Session(_SessionClassMethods): ) self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded]) - def _expunge_states(self, states, to_transient=False): + def _expunge_states( + self, states: Iterable[InstanceState[Any]], to_transient: bool = False + ) -> None: for state in states: if state in self._new: self._new.pop(state) @@ -2388,7 +2570,7 @@ class Session(_SessionClassMethods): states, self, to_transient=to_transient ) - def _register_persistent(self, states): + def _register_persistent(self, states: Set[InstanceState[Any]]) -> None: """Register all persistent objects from a flush. This is used both for pending objects moving to the persistent @@ -2429,11 +2611,13 @@ class Session(_SessionClassMethods): # state has already replaced this one in the identity # map (see test/orm/test_naturalpks.py ReversePKsTest) self.identity_map.safe_discard(state) - if state in self._transaction._key_switches: - orig_key = self._transaction._key_switches[state][0] + trans = self._transaction + assert trans is not None + if state in trans._key_switches: + orig_key = trans._key_switches[state][0] else: orig_key = state.key - self._transaction._key_switches[state] = ( + trans._key_switches[state] = ( orig_key, instance_key, ) @@ -2470,7 +2654,7 @@ class Session(_SessionClassMethods): for state in set(states).intersection(self._new): self._new.pop(state) - def _register_altered(self, states): + def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None: if self._transaction: for state in states: if state in self._new: @@ -2478,7 +2662,9 @@ class Session(_SessionClassMethods): else: self._transaction._dirty[state] = True - def _remove_newly_deleted(self, states): + def _remove_newly_deleted( + self, states: Iterable[InstanceState[Any]] + ) -> None: persistent_to_deleted = self.dispatch.persistent_to_deleted or None for state in states: if self._transaction: @@ -2498,7 +2684,7 @@ class Session(_SessionClassMethods): if persistent_to_deleted is not None: persistent_to_deleted(self, state) - def add(self, instance: Any, _warn: bool = True) -> None: + def add(self, instance: object, _warn: bool = True) -> None: """Place an object in the ``Session``. Its state will be persisted to the database on the next flush @@ -2518,7 +2704,7 @@ class Session(_SessionClassMethods): self._save_or_update_state(state) - def add_all(self, instances): + def add_all(self, instances: Iterable[object]) -> None: """Add the given collection of instances to this ``Session``.""" if self._warn_on_events: @@ -2527,7 +2713,7 @@ class Session(_SessionClassMethods): for instance in instances: self.add(instance, _warn=False) - def _save_or_update_state(self, state): + def _save_or_update_state(self, state: InstanceState[Any]) -> None: state._orphaned_outside_of_session = False self._save_or_update_impl(state) @@ -2537,7 +2723,7 @@ class Session(_SessionClassMethods): ): self._save_or_update_impl(st_) - def delete(self, instance): + def delete(self, instance: object) -> None: """Mark an instance as deleted. The database delete operation occurs upon ``flush()``. @@ -2553,7 +2739,9 @@ class Session(_SessionClassMethods): self._delete_impl(state, instance, head=True) - def _delete_impl(self, state, obj, head): + def _delete_impl( + self, state: InstanceState[Any], obj: object, head: bool + ) -> None: if state.key is None: if head: @@ -2580,23 +2768,28 @@ class Session(_SessionClassMethods): cascade_states = list( state.manager.mapper.cascade_iterator("delete", state) ) + else: + cascade_states = None self._deleted[state] = obj if head: + if TYPE_CHECKING: + assert cascade_states is not None for o, m, st_, dct_ in cascade_states: self._delete_impl(st_, o, False) def get( self, - entity, - ident, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - execution_options=None, - ): + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Optional[_O]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -2696,7 +2889,7 @@ class Session(_SessionClassMethods): entity, ident, loading.load_on_pk_identity, - options, + options=options, populate_existing=populate_existing, with_for_update=with_for_update, identity_token=identity_token, @@ -2705,23 +2898,24 @@ class Session(_SessionClassMethods): def _get_impl( self, - entity, - primary_key_identity, - db_load_fn, - options=None, - populate_existing=False, - with_for_update=None, - identity_token=None, - execution_options=None, - ): + entity: _EntityBindKey[_O], + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., _O], + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: Optional[ForUpdateArg] = None, + identity_token: Optional[Any] = None, + execution_options: Optional[_ExecuteOptionsParameter] = None, + ) -> Optional[_O]: # convert composite types to individual args - if hasattr(primary_key_identity, "__composite_values__"): + if is_composite_class(primary_key_identity): primary_key_identity = primary_key_identity.__composite_values__() - mapper = inspect(entity) + mapper: Optional[Mapper[_O]] = inspect(entity) - if not mapper or not mapper.is_mapper: + if mapper is None or not mapper.is_mapper: raise sa_exc.ArgumentError( "Expected mapped class or mapper, got: %r" % entity ) @@ -2729,7 +2923,7 @@ class Session(_SessionClassMethods): is_dict = isinstance(primary_key_identity, dict) if not is_dict: primary_key_identity = util.to_list( - primary_key_identity, default=(None,) + primary_key_identity, default=[None] ) if len(primary_key_identity) != len(mapper.primary_key): @@ -2770,11 +2964,12 @@ class Session(_SessionClassMethods): if instance is not None: # reject calls for id in identity map but class # mismatch. - if not issubclass(instance.__class__, mapper.class_): + if not isinstance(instance, mapper.class_): return None return instance - elif instance is attributes.PASSIVE_CLASS_MISMATCH: - return None + + # TODO: this was being tested before, but this is not possible + assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH # set_label_style() not strictly necessary, however this will ensure # that tablename_colname style is used which at the moment is @@ -2788,7 +2983,7 @@ class Session(_SessionClassMethods): LABEL_STYLE_TABLENAME_PLUS_COL ) if with_for_update is not None: - statement._for_update_arg = query.ForUpdateArg._from_argument( + statement._for_update_arg = ForUpdateArg._from_argument( with_for_update ) @@ -2803,7 +2998,13 @@ class Session(_SessionClassMethods): load_options=load_options, ) - def merge(self, instance, load=True, options=None): + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: """Copy the state of a given instance into a corresponding instance within this :class:`.Session`. @@ -2866,8 +3067,8 @@ class Session(_SessionClassMethods): if self._warn_on_events: self._flush_warning("Session.merge()") - _recursive = {} - _resolve_conflict_map = {} + _recursive: Dict[InstanceState[Any], object] = {} + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {} if load: # flush current contents if we expect to load data @@ -2890,20 +3091,23 @@ class Session(_SessionClassMethods): def _merge( self, - state, - state_dict, - load=True, - options=None, - _recursive=None, - _resolve_conflict_map=None, - ): + state: InstanceState[_O], + state_dict: _InstanceDict, + *, + options: Optional[Sequence[ORMOption]] = None, + load: bool, + _recursive: Dict[InstanceState[Any], object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> _O: mapper = _state_mapper(state) if state in _recursive: - return _recursive[state] + return cast(_O, _recursive[state]) new_instance = False key = state.key + merged: Optional[_O] + if key is None: if state in self._new: util.warn( @@ -2920,7 +3124,9 @@ class Session(_SessionClassMethods): "load=False." ) key = mapper._identity_key_from_state(state) - key_is_persistent = attributes.NEVER_SET not in key[1] and ( + key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[ + 1 + ] and ( not _none_set.intersection(key[1]) or ( mapper.allow_partial_pks @@ -2941,7 +3147,7 @@ class Session(_SessionClassMethods): if merged is None: if key_is_persistent and key in _resolve_conflict_map: - merged = _resolve_conflict_map[key] + merged = cast(_O, _resolve_conflict_map[key]) elif not load: if state.modified: @@ -2986,19 +3192,21 @@ class Session(_SessionClassMethods): state, state_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) merged_version = mapper._get_state_attr_by_column( merged_state, merged_dict, mapper.version_id_col, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if ( - existing_version is not attributes.PASSIVE_NO_RESULT - and merged_version is not attributes.PASSIVE_NO_RESULT + existing_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and merged_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT and existing_version != merged_version ): raise exc.StaleDataError( @@ -3043,14 +3251,14 @@ class Session(_SessionClassMethods): merged_state.manager.dispatch.load(merged_state, None) return merged - def _validate_persistent(self, state): + def _validate_persistent(self, state: InstanceState[Any]) -> None: if not self.identity_map.contains_state(state): raise sa_exc.InvalidRequestError( "Instance '%s' is not persistent within this Session" % state_str(state) ) - def _save_impl(self, state): + def _save_impl(self, state: InstanceState[Any]) -> None: if state.key is not None: raise sa_exc.InvalidRequestError( "Object '%s' already has an identity - " @@ -3065,7 +3273,9 @@ class Session(_SessionClassMethods): if to_attach: self._after_attach(state, obj) - def _update_impl(self, state, revert_deletion=False): + def _update_impl( + self, state: InstanceState[Any], revert_deletion: bool = False + ) -> None: if state.key is None: raise sa_exc.InvalidRequestError( "Instance '%s' is not persisted" % state_str(state) @@ -3103,13 +3313,13 @@ class Session(_SessionClassMethods): elif revert_deletion: self.dispatch.deleted_to_persistent(self, state) - def _save_or_update_impl(self, state): + def _save_or_update_impl(self, state: InstanceState[Any]) -> None: if state.key is None: self._save_impl(state) else: self._update_impl(state) - def enable_relationship_loading(self, obj): + def enable_relationship_loading(self, obj: object) -> None: """Associate an object with this :class:`.Session` for related object loading. @@ -3174,8 +3384,8 @@ class Session(_SessionClassMethods): if to_attach: self._after_attach(state, obj) - def _before_attach(self, state, obj): - self._autobegin() + def _before_attach(self, state: InstanceState[Any], obj: object) -> bool: + self._autobegin_t() if state.session_id == self.hash_key: return False @@ -3191,7 +3401,7 @@ class Session(_SessionClassMethods): return True - def _after_attach(self, state, obj): + def _after_attach(self, state: InstanceState[Any], obj: object) -> None: state.session_id = self.hash_key if state.modified and state._strong_obj is None: state._strong_obj = obj @@ -3202,7 +3412,7 @@ class Session(_SessionClassMethods): else: self.dispatch.transient_to_pending(self, state) - def __contains__(self, instance): + def __contains__(self, instance: object) -> bool: """Return True if the instance is associated with this session. The instance may be pending or persistent within the Session for a @@ -3215,7 +3425,7 @@ class Session(_SessionClassMethods): raise exc.UnmappedInstanceError(instance) from err return self._contains_state(state) - def __iter__(self): + def __iter__(self) -> Iterator[object]: """Iterate over all pending or persistent instances within this Session. @@ -3224,10 +3434,10 @@ class Session(_SessionClassMethods): list(self._new.values()) + list(self.identity_map.values()) ) - def _contains_state(self, state): + def _contains_state(self, state: InstanceState[Any]) -> bool: return state in self._new or self.identity_map.contains_state(state) - def flush(self, objects=None): + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: """Flush all the object changes to the database. Writes out all pending object creations, deletions and modifications @@ -3261,7 +3471,7 @@ class Session(_SessionClassMethods): finally: self._flushing = False - def _flush_warning(self, method): + def _flush_warning(self, method: Any) -> None: util.warn( "Usage of the '%s' operation is not currently supported " "within the execution stage of the flush process. " @@ -3269,14 +3479,14 @@ class Session(_SessionClassMethods): "event listeners or connection-level operations instead." % method ) - def _is_clean(self): + def _is_clean(self) -> bool: return ( not self.identity_map.check_modified() and not self._deleted and not self._new ) - def _flush(self, objects=None): + def _flush(self, objects: Optional[Sequence[object]] = None) -> None: dirty = self._dirty_states if not dirty and not self._deleted and not self._new: @@ -3398,11 +3608,11 @@ class Session(_SessionClassMethods): def bulk_save_objects( self, - objects, - return_defaults=False, - update_changed_only=True, - preserve_order=True, - ): + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: """Perform a bulk save of the given list of objects. The bulk save feature allows mapped objects to be used as the @@ -3496,6 +3706,8 @@ class Session(_SessionClassMethods): """ + obj_states: Iterable[InstanceState[Any]] + obj_states = (attributes.instance_state(obj) for obj in objects) if not preserve_order: @@ -3508,7 +3720,9 @@ class Session(_SessionClassMethods): key=lambda state: (id(state.mapper), state.key is not None), ) - def grouping_key(state): + def grouping_key( + state: InstanceState[_O], + ) -> Tuple[Mapper[_O], bool]: return (state.mapper, state.key is not None) for (mapper, isupdate), states in itertools.groupby( @@ -3525,8 +3739,12 @@ class Session(_SessionClassMethods): ) def bulk_insert_mappings( - self, mapper, mappings, return_defaults=False, render_nulls=False - ): + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: """Perform a bulk insert of the given list of mapping dictionaries. The bulk insert feature allows plain Python dictionaries to be used as @@ -3633,7 +3851,9 @@ class Session(_SessionClassMethods): render_nulls, ) - def bulk_update_mappings(self, mapper, mappings): + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: """Perform a bulk update of the given list of mapping dictionaries. The bulk update feature allows plain Python dictionaries to be used as @@ -3696,14 +3916,14 @@ class Session(_SessionClassMethods): def _bulk_save_mappings( self, - mapper, - mappings, - isupdate, - isstates, - return_defaults, - update_changed_only, - render_nulls, - ): + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + isupdate: bool, + isstates: bool, + return_defaults: bool, + update_changed_only: bool, + render_nulls: bool, + ) -> None: mapper = _class_to_mapper(mapper) self._flushing = True @@ -3734,7 +3954,9 @@ class Session(_SessionClassMethods): finally: self._flushing = False - def is_modified(self, instance, include_collections=True): + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: r"""Return ``True`` if the given instance has locally modified attributes. @@ -3800,7 +4022,7 @@ class Session(_SessionClassMethods): continue (added, unchanged, deleted) = attr.impl.get_history( - state, dict_, passive=attributes.NO_CHANGE + state, dict_, passive=PassiveFlag.NO_CHANGE ) if added or deleted: @@ -3809,7 +4031,7 @@ class Session(_SessionClassMethods): return False @property - def is_active(self): + def is_active(self) -> bool: """True if this :class:`.Session` not in "partial rollback" state. .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins @@ -3838,22 +4060,8 @@ class Session(_SessionClassMethods): """ return self._transaction is None or self._transaction.is_active - identity_map = None - """A mapping of object identities to objects themselves. - - Iterating through ``Session.identity_map.values()`` provides - access to the full set of persistent objects (i.e., those - that have row identity) currently in the session. - - .. seealso:: - - :func:`.identity_key` - helper function to produce the keys used - in this dictionary. - - """ - @property - def _dirty_states(self): + def _dirty_states(self) -> Iterable[InstanceState[Any]]: """The set of all persistent states considered dirty. This method returns all states that were modified including @@ -3863,7 +4071,7 @@ class Session(_SessionClassMethods): return self.identity_map._dirty_states() @property - def dirty(self): + def dirty(self) -> IdentitySet: """The set of all persistent instances considered dirty. E.g.:: @@ -3886,7 +4094,7 @@ class Session(_SessionClassMethods): attributes, use the :meth:`.Session.is_modified` method. """ - return util.IdentitySet( + return IdentitySet( [ state.obj() for state in self._dirty_states @@ -3895,13 +4103,13 @@ class Session(_SessionClassMethods): ) @property - def deleted(self): + def deleted(self) -> IdentitySet: "The set of all instances marked as 'deleted' within this ``Session``" return util.IdentitySet(list(self._deleted.values())) @property - def new(self): + def new(self) -> IdentitySet: "The set of all instances marked as 'new' within this ``Session``." return util.IdentitySet(list(self._new.values())) @@ -4002,14 +4210,16 @@ class sessionmaker(_SessionClassMethods): """ + class_: Type[Session] + def __init__( self, - bind=None, - class_=Session, - autoflush=True, - expire_on_commit=True, - info=None, - **kw, + bind: Optional[_SessionBind] = None, + class_: Type[Session] = Session, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + **kw: Any, ): r"""Construct a new :class:`.sessionmaker`. @@ -4052,7 +4262,7 @@ class sessionmaker(_SessionClassMethods): # events can be associated with it specifically. self.class_ = type(class_.__name__, (class_,), {}) - def begin(self): + def begin(self) -> contextlib.AbstractContextManager[Session]: """Produce a context manager that both provides a new :class:`_orm.Session` as well as a transaction that commits. @@ -4074,7 +4284,7 @@ class sessionmaker(_SessionClassMethods): session = self() return session._maker_context_manager() - def __call__(self, **local_kw): + def __call__(self, **local_kw: Any) -> Session: """Produce a new :class:`.Session` object using the configuration established in this :class:`.sessionmaker`. @@ -4094,7 +4304,7 @@ class sessionmaker(_SessionClassMethods): local_kw.setdefault(k, v) return self.class_(**local_kw) - def configure(self, **new_kw): + def configure(self, **new_kw: Any) -> None: """(Re)configure the arguments for this sessionmaker. e.g.:: @@ -4105,7 +4315,7 @@ class sessionmaker(_SessionClassMethods): """ self.kw.update(new_kw) - def __repr__(self): + def __repr__(self) -> str: return "%s(class_=%r, %s)" % ( self.__class__.__name__, self.class_.__name__, @@ -4113,7 +4323,7 @@ class sessionmaker(_SessionClassMethods): ) -def close_all_sessions(): +def close_all_sessions() -> None: """Close all sessions in memory. This function consults a global registry of all :class:`.Session` objects @@ -4131,7 +4341,7 @@ def close_all_sessions(): sess.close() -def make_transient(instance): +def make_transient(instance: object) -> None: """Alter the state of the given instance so that it is :term:`transient`. .. note:: @@ -4195,7 +4405,7 @@ def make_transient(instance): del state._deleted -def make_transient_to_detached(instance): +def make_transient_to_detached(instance: object) -> None: """Make the given transient instance :term:`detached`. .. note:: @@ -4234,7 +4444,7 @@ def make_transient_to_detached(instance): state._expire_attributes(state.dict, state.unloaded_expirable) -def object_session(instance): +def object_session(instance: object) -> Optional[Session]: """Return the :class:`.Session` to which the given instance belongs. This is essentially the same as the :attr:`.InstanceState.session` |
