summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-07 12:37:23 -0400
committermike bayer <mike_mp@zzzcomputing.com>2022-04-12 02:09:50 +0000
commitaa9cd878e8249a4a758c7f968e929e92fede42a5 (patch)
tree1be1c9dc24dd247a150be55d65bfc56ebaf111bc /lib/sqlalchemy/orm/session.py
parent98eae4e181cb2d1bbc67ec834bfad29dcba7f461 (diff)
downloadsqlalchemy-aa9cd878e8249a4a758c7f968e929e92fede42a5.tar.gz
pep-484: session, instancestate, etc
Also adds some fixes to annotation-based mapping that have come up, as well as starts to add more pep-484 test cases Change-Id: Ia722bbbc7967a11b23b66c8084eb61df9d233fee
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py936
1 files changed, 573 insertions, 363 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 77a97936b..55ce73cf5 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -13,12 +13,20 @@ import itertools
import sys
import typing
from typing import Any
+from typing import Callable
+from typing import cast
from typing import Dict
+from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import NoReturn
from typing import Optional
-from typing import overload
+from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -30,14 +38,20 @@ from . import loading
from . import persistence
from . import query
from . import state as statelib
+from ._typing import is_composite_class
+from ._typing import is_user_defined_option
from .base import _class_to_mapper
-from .base import _IdentityKeyType
from .base import _none_set
from .base import _state_mapper
from .base import instance_str
+from .base import LoaderCallableStatus
from .base import object_mapper
from .base import object_state
+from .base import PassiveFlag
from .base import state_str
+from .context import FromStatement
+from .context import ORMCompileState
+from .identity import IdentityMap
from .query import Query
from .state import InstanceState
from .state_changes import _StateChange
@@ -51,22 +65,41 @@ from .. import util
from ..engine import Connection
from ..engine import Engine
from ..engine.util import TransactionalContext
+from ..event import dispatcher
+from ..event import EventTarget
from ..inspection import inspect
from ..sql import coercions
from ..sql import dml
from ..sql import roles
+from ..sql import Select
from ..sql import visitors
from ..sql.base import CompileState
+from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import IdentitySet
from ..util.typing import Literal
+from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
+ from .interfaces import ORMOption
+ from .interfaces import UserDefinedOption
from .mapper import Mapper
+ from .path_registry import PathRegistry
+ from ..engine import Result
from ..engine import Row
+ from ..engine.base import Transaction
+ from ..engine.base import TwoPhaseTransaction
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
+ from ..engine.interfaces import _ExecuteOptions
+ from ..engine.interfaces import _ExecuteOptionsParameter
+ from ..engine.result import ScalarResult
+ from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
- from ..sql._typing import _ExecuteOptions
- from ..sql._typing import _ExecuteParams
from ..sql.base import Executable
+ from ..sql.elements import ClauseElement
from ..sql.schema import Table
__all__ = [
@@ -80,14 +113,45 @@ __all__ = [
"object_session",
]
-_sessions = weakref.WeakValueDictionary()
+_sessions: weakref.WeakValueDictionary[
+ int, Session
+] = weakref.WeakValueDictionary()
"""Weak-referencing dictionary of :class:`.Session` objects.
"""
+_O = TypeVar("_O", bound=object)
statelib._sessions = _sessions
+_PKIdentityArgument = Union[Any, Tuple[Any, ...]]
-def _state_session(state):
+_EntityBindKey = Union[Type[_O], "Mapper[_O]"]
+_SessionBindKey = Union[Type[Any], "Mapper[Any]", "Table"]
+_SessionBind = Union["Engine", "Connection"]
+
+
+class _ConnectionCallableProto(Protocol):
+ """a callable that returns a :class:`.Connection` given an instance.
+
+ This callable, when present on a :class:`.Session`, is called only from the
+ ORM's persistence mechanism (i.e. the unit of work flush process) to allow
+ for connection-per-instance schemes (i.e. horizontal sharding) to be used
+ as persistence time.
+
+ This callable is not present on a plain :class:`.Session`, however
+ is established when using the horizontal sharding extension.
+
+ """
+
+ def __call__(
+ self,
+ mapper: Optional[Mapper[Any]] = None,
+ instance: Optional[object] = None,
+ **kw: Any,
+ ) -> Connection:
+ ...
+
+
+def _state_session(state: InstanceState[Any]) -> Optional[Session]:
"""Given an :class:`.InstanceState`, return the :class:`.Session`
associated, if any.
"""
@@ -110,39 +174,16 @@ class _SessionClassMethods:
close_all_sessions()
@classmethod
- @overload
- def identity_key(
- cls,
- class_: type,
- ident: Tuple[Any, ...],
- *,
- identity_token: Optional[str],
- ) -> _IdentityKeyType:
- ...
-
- @classmethod
- @overload
- def identity_key(cls, *, instance: Any) -> _IdentityKeyType:
- ...
-
- @classmethod
- @overload
- def identity_key(
- cls, class_: type, *, row: "Row", identity_token: Optional[str]
- ) -> _IdentityKeyType:
- ...
-
- @classmethod
@util.preload_module("sqlalchemy.orm.util")
def identity_key(
cls,
- class_=None,
- ident=None,
+ class_: Optional[Type[Any]] = None,
+ ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance=None,
- row=None,
- identity_token=None,
- ) -> _IdentityKeyType:
+ instance: Optional[Any] = None,
+ row: Optional[Row] = None,
+ identity_token: Optional[Any] = None,
+ ) -> _IdentityKeyType[Any]:
"""Return an identity key.
This is an alias of :func:`.util.identity_key`.
@@ -157,7 +198,7 @@ class _SessionClassMethods:
)
@classmethod
- def object_session(cls, instance: Any) -> "Session":
+ def object_session(cls, instance: object) -> Optional[Session]:
"""Return the :class:`.Session` to which an object belongs.
This is an alias of :func:`.object_session`.
@@ -205,26 +246,26 @@ class ORMExecuteState(util.MemoizedSlots):
"_update_execution_options",
)
- session: "Session"
- statement: "Executable"
- parameters: "_ExecuteParams"
- execution_options: "_ExecuteOptions"
- local_execution_options: "_ExecuteOptions"
+ session: Session
+ statement: Executable
+ parameters: Optional[_CoreAnyExecuteParams]
+ execution_options: _ExecuteOptions
+ local_execution_options: _ExecuteOptions
bind_arguments: Dict[str, Any]
- _compile_state_cls: Type[context.ORMCompileState]
- _starting_event_idx: Optional[int]
+ _compile_state_cls: Optional[Type[ORMCompileState]]
+ _starting_event_idx: int
_events_todo: List[Any]
- _update_execution_options: Optional["_ExecuteOptions"]
+ _update_execution_options: Optional[_ExecuteOptions]
def __init__(
self,
- session: "Session",
- statement: "Executable",
- parameters: "_ExecuteParams",
- execution_options: "_ExecuteOptions",
+ session: Session,
+ statement: Executable,
+ parameters: Optional[_CoreAnyExecuteParams],
+ execution_options: _ExecuteOptions,
bind_arguments: Dict[str, Any],
- compile_state_cls: Type[context.ORMCompileState],
- events_todo: List[Any],
+ compile_state_cls: Optional[Type[ORMCompileState]],
+ events_todo: List[_InstanceLevelDispatch[Session]],
):
self.session = session
self.statement = statement
@@ -237,16 +278,16 @@ class ORMExecuteState(util.MemoizedSlots):
self._compile_state_cls = compile_state_cls
self._events_todo = list(events_todo)
- def _remaining_events(self):
+ def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]:
return self._events_todo[self._starting_event_idx + 1 :]
def invoke_statement(
self,
- statement=None,
- params=None,
- execution_options=None,
- bind_arguments=None,
- ):
+ statement: Optional[Executable] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ ) -> Result:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
@@ -270,9 +311,12 @@ class ORMExecuteState(util.MemoizedSlots):
:param statement: optional statement to be invoked, in place of the
statement currently represented by :attr:`.ORMExecuteState.statement`.
- :param params: optional dictionary of parameters which will be merged
- into the existing :attr:`.ORMExecuteState.parameters` of this
- :class:`.ORMExecuteState`.
+ :param params: optional dictionary of parameters or list of parameters
+ which will be merged into the existing
+ :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`.
+
+ .. versionchanged:: 2.0 a list of parameter dictionaries is accepted
+ for executemany executions.
:param execution_options: optional dictionary of execution options
will be merged into the existing
@@ -302,9 +346,32 @@ class ORMExecuteState(util.MemoizedSlots):
_bind_arguments.update(bind_arguments)
_bind_arguments["_sa_skip_events"] = True
+ _params: Optional[_CoreAnyExecuteParams]
if params:
- _params = dict(self.parameters)
- _params.update(params)
+ if self.is_executemany:
+ _params = []
+ exec_many_parameters = cast(
+ "List[Dict[str, Any]]", self.parameters
+ )
+ for _existing_params, _new_params in itertools.zip_longest(
+ exec_many_parameters,
+ cast("List[Dict[str, Any]]", params),
+ ):
+ if _existing_params is None or _new_params is None:
+ raise sa_exc.InvalidRequestError(
+ f"Can't apply executemany parameters to "
+ f"statement; number of parameter sets passed to "
+ f"Session.execute() ({len(exec_many_parameters)}) "
+ f"does not match number of parameter sets given "
+ f"to ORMExecuteState.invoke_statement() "
+ f"({len(params)})"
+ )
+ _existing_params = dict(_existing_params)
+ _existing_params.update(_new_params)
+ _params.append(_existing_params)
+ else:
+ _params = dict(cast("Dict[str, Any]", self.parameters))
+ _params.update(cast("Dict[str, Any]", params))
else:
_params = self.parameters
@@ -321,7 +388,7 @@ class ORMExecuteState(util.MemoizedSlots):
)
@property
- def bind_mapper(self):
+ def bind_mapper(self) -> Optional[Mapper[Any]]:
"""Return the :class:`_orm.Mapper` that is the primary "bind" mapper.
For an :class:`_orm.ORMExecuteState` object invoking an ORM
@@ -349,7 +416,7 @@ class ORMExecuteState(util.MemoizedSlots):
return self.bind_arguments.get("mapper", None)
@property
- def all_mappers(self):
+ def all_mappers(self) -> Sequence[Mapper[Any]]:
"""Return a sequence of all :class:`_orm.Mapper` objects that are
involved at the top level of this statement.
@@ -369,7 +436,7 @@ class ORMExecuteState(util.MemoizedSlots):
"""
if not self.is_orm_statement:
return []
- elif self.is_select:
+ elif isinstance(self.statement, (Select, FromStatement)):
result = []
seen = set()
for d in self.statement.column_descriptions:
@@ -380,13 +447,13 @@ class ORMExecuteState(util.MemoizedSlots):
seen.add(insp.mapper)
result.append(insp.mapper)
return result
- elif self.is_update or self.is_delete:
+ elif self.statement.is_dml and self.bind_mapper:
return [self.bind_mapper]
else:
return []
@property
- def is_orm_statement(self):
+ def is_orm_statement(self) -> bool:
"""return True if the operation is an ORM statement.
This indicates that the select(), update(), or delete() being
@@ -399,44 +466,64 @@ class ORMExecuteState(util.MemoizedSlots):
return self._compile_state_cls is not None
@property
- def is_select(self):
+ def is_executemany(self) -> bool:
+ """return True if the parameters are a multi-element list of
+ dictionaries with more than one dictionary.
+
+ .. versionadded:: 2.0
+
+ """
+ return isinstance(self.parameters, list)
+
+ @property
+ def is_select(self) -> bool:
"""return True if this is a SELECT operation."""
return self.statement.is_select
@property
- def is_insert(self):
+ def is_insert(self) -> bool:
"""return True if this is an INSERT operation."""
return self.statement.is_dml and self.statement.is_insert
@property
- def is_update(self):
+ def is_update(self) -> bool:
"""return True if this is an UPDATE operation."""
return self.statement.is_dml and self.statement.is_update
@property
- def is_delete(self):
+ def is_delete(self) -> bool:
"""return True if this is a DELETE operation."""
return self.statement.is_dml and self.statement.is_delete
@property
- def _is_crud(self):
+ def _is_crud(self) -> bool:
return isinstance(self.statement, (dml.Update, dml.Delete))
- def update_execution_options(self, **opts):
+ def update_execution_options(self, **opts: _ExecuteOptions) -> None:
+ """Update the local execution options with new values."""
# TODO: no coverage
self.local_execution_options = self.local_execution_options.union(opts)
- def _orm_compile_options(self):
+ def _orm_compile_options(
+ self,
+ ) -> Optional[
+ Union[
+ context.ORMCompileState.default_compile_options,
+ Type[context.ORMCompileState.default_compile_options],
+ ]
+ ]:
if not self.is_select:
return None
opts = self.statement._compile_options
- if opts.isinstance(context.ORMCompileState.default_compile_options):
- return opts
+ if opts is not None and opts.isinstance(
+ context.ORMCompileState.default_compile_options
+ ):
+ return opts # type: ignore
else:
return None
@property
- def lazy_loaded_from(self):
+ def lazy_loaded_from(self) -> Optional[InstanceState[Any]]:
"""An :class:`.InstanceState` that is using this statement execution
for a lazy load operation.
@@ -451,7 +538,7 @@ class ORMExecuteState(util.MemoizedSlots):
return self.load_options._lazy_loaded_from
@property
- def loader_strategy_path(self):
+ def loader_strategy_path(self) -> Optional[PathRegistry]:
"""Return the :class:`.PathRegistry` for the current load path.
This object represents the "path" in a query along relationships
@@ -465,7 +552,7 @@ class ORMExecuteState(util.MemoizedSlots):
return None
@property
- def is_column_load(self):
+ def is_column_load(self) -> bool:
"""Return True if the operation is refreshing column-oriented
attributes on an existing ORM object.
@@ -492,7 +579,7 @@ class ORMExecuteState(util.MemoizedSlots):
return opts is not None and opts._for_refresh_state
@property
- def is_relationship_load(self):
+ def is_relationship_load(self) -> bool:
"""Return True if this load is loading objects on behalf of a
relationship.
@@ -518,7 +605,12 @@ class ORMExecuteState(util.MemoizedSlots):
return path is not None and not path.is_root
@property
- def load_options(self):
+ def load_options(
+ self,
+ ) -> Union[
+ context.QueryContext.default_load_options,
+ Type[context.QueryContext.default_load_options],
+ ]:
"""Return the load_options that will be used for this execution."""
if not self.is_select:
@@ -531,7 +623,12 @@ class ORMExecuteState(util.MemoizedSlots):
)
@property
- def update_delete_options(self):
+ def update_delete_options(
+ self,
+ ) -> Union[
+ persistence.BulkUDCompileState.default_update_options,
+ Type[persistence.BulkUDCompileState.default_update_options],
+ ]:
"""Return the update_delete_options that will be used for this
execution."""
@@ -546,7 +643,7 @@ class ORMExecuteState(util.MemoizedSlots):
)
@property
- def user_defined_options(self):
+ def user_defined_options(self) -> Sequence[UserDefinedOption]:
"""The sequence of :class:`.UserDefinedOptions` that have been
associated with the statement being invoked.
@@ -554,7 +651,7 @@ class ORMExecuteState(util.MemoizedSlots):
return [
opt
for opt in self.statement._with_options
- if not opt._is_compile_state and not opt._is_legacy_option
+ if is_user_defined_option(opt)
]
@@ -597,14 +694,29 @@ class SessionTransaction(_StateChange, TransactionalContext):
"""
- _rollback_exception = None
+ _rollback_exception: Optional[BaseException] = None
+
+ _connections: Dict[
+ Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool]
+ ]
+ session: Session
+ _parent: Optional[SessionTransaction]
+
+ _state: SessionTransactionState
+
+ _new: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object]
+ _key_switches: weakref.WeakKeyDictionary[
+ InstanceState[Any], Tuple[Any, Any]
+ ]
def __init__(
self,
- session,
- parent=None,
- nested=False,
- autobegin=False,
+ session: Session,
+ parent: Optional[SessionTransaction] = None,
+ nested: bool = False,
+ autobegin: bool = False,
):
TransactionalContext._trans_ctx_check(session)
@@ -629,7 +741,9 @@ class SessionTransaction(_StateChange, TransactionalContext):
self.session.dispatch.after_transaction_create(self.session, self)
- def _raise_for_prerequisite_state(self, operation_name, state):
+ def _raise_for_prerequisite_state(
+ self, operation_name: str, state: SessionTransactionState
+ ) -> NoReturn:
if state is SessionTransactionState.DEACTIVE:
if self._rollback_exception:
raise sa_exc.PendingRollbackError(
@@ -655,7 +769,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
)
@property
- def parent(self):
+ def parent(self) -> Optional[SessionTransaction]:
"""The parent :class:`.SessionTransaction` of this
:class:`.SessionTransaction`.
@@ -673,7 +787,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
"""
return self._parent
- nested = False
+ nested: bool = False
"""Indicates if this is a nested, or SAVEPOINT, transaction.
When :attr:`.SessionTransaction.nested` is True, it is expected
@@ -682,33 +796,40 @@ class SessionTransaction(_StateChange, TransactionalContext):
"""
@property
- def is_active(self):
+ def is_active(self) -> bool:
return (
self.session is not None
and self._state is SessionTransactionState.ACTIVE
)
@property
- def _is_transaction_boundary(self):
+ def _is_transaction_boundary(self) -> bool:
return self.nested or not self._parent
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def connection(self, bindkey, execution_options=None, **kwargs):
+ def connection(
+ self,
+ bindkey: Optional[Mapper[Any]],
+ execution_options: Optional[_ExecuteOptions] = None,
+ **kwargs: Any,
+ ) -> Connection:
bind = self.session.get_bind(bindkey, **kwargs)
return self._connection_for_bind(bind, execution_options)
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def _begin(self, nested=False):
+ def _begin(self, nested: bool = False) -> SessionTransaction:
return SessionTransaction(self.session, self, nested=nested)
- def _iterate_self_and_parents(self, upto=None):
+ def _iterate_self_and_parents(
+ self, upto: Optional[SessionTransaction] = None
+ ) -> Iterable[SessionTransaction]:
current = self
- result = ()
+ result: Tuple[SessionTransaction, ...] = ()
while current:
result += (current,)
if current._parent is upto:
@@ -723,12 +844,14 @@ class SessionTransaction(_StateChange, TransactionalContext):
return result
- def _take_snapshot(self, autobegin=False):
+ def _take_snapshot(self, autobegin: bool = False) -> None:
if not self._is_transaction_boundary:
- self._new = self._parent._new
- self._deleted = self._parent._deleted
- self._dirty = self._parent._dirty
- self._key_switches = self._parent._key_switches
+ parent = self._parent
+ assert parent is not None
+ self._new = parent._new
+ self._deleted = parent._deleted
+ self._dirty = parent._dirty
+ self._key_switches = parent._key_switches
return
if not autobegin and not self.session._flushing:
@@ -739,7 +862,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
self._dirty = weakref.WeakKeyDictionary()
self._key_switches = weakref.WeakKeyDictionary()
- def _restore_snapshot(self, dirty_only=False):
+ def _restore_snapshot(self, dirty_only: bool = False) -> None:
"""Restore the restoration state taken before a transaction began.
Corresponds to a rollback.
@@ -771,7 +894,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
if not dirty_only or s.modified or s in self._dirty:
s._expire(s.dict, self.session.identity_map._modified)
- def _remove_snapshot(self):
+ def _remove_snapshot(self) -> None:
"""Remove the restoration state taken before a transaction began.
Corresponds to a commit.
@@ -788,15 +911,21 @@ class SessionTransaction(_StateChange, TransactionalContext):
)
self._deleted.clear()
elif self.nested:
- self._parent._new.update(self._new)
- self._parent._dirty.update(self._dirty)
- self._parent._deleted.update(self._deleted)
- self._parent._key_switches.update(self._key_switches)
+ parent = self._parent
+ assert parent is not None
+ parent._new.update(self._new)
+ parent._dirty.update(self._dirty)
+ parent._deleted.update(self._deleted)
+ parent._key_switches.update(self._key_switches)
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE
)
- def _connection_for_bind(self, bind, execution_options):
+ def _connection_for_bind(
+ self,
+ bind: _SessionBind,
+ execution_options: Optional[_ExecuteOptions],
+ ) -> Connection:
if bind in self._connections:
if execution_options:
@@ -829,6 +958,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
if execution_options:
conn = conn.execution_options(**execution_options)
+ transaction: Transaction
if self.session.twophase and self._parent is None:
transaction = conn.begin_twophase()
elif self.nested:
@@ -837,9 +967,9 @@ class SessionTransaction(_StateChange, TransactionalContext):
# if given a future connection already in a transaction, don't
# commit that transaction unless it is a savepoint
if conn.in_nested_transaction():
- transaction = conn.get_nested_transaction()
+ transaction = conn._get_required_nested_transaction()
else:
- transaction = conn.get_transaction()
+ transaction = conn._get_required_transaction()
should_commit = False
else:
transaction = conn.begin()
@@ -861,7 +991,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
self.session.dispatch.after_begin(self.session, self, conn)
return conn
- def prepare(self):
+ def prepare(self) -> None:
if self._parent is not None or not self.session.twophase:
raise sa_exc.InvalidRequestError(
"'twophase' mode not enabled, or not root transaction; "
@@ -872,12 +1002,13 @@ class SessionTransaction(_StateChange, TransactionalContext):
@_StateChange.declare_states(
(SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED
)
- def _prepare_impl(self):
+ def _prepare_impl(self) -> None:
if self._parent is None or self.nested:
self.session.dispatch.before_commit(self.session)
stx = self.session._transaction
+ assert stx is not None
if stx is not self:
for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.commit()
@@ -897,7 +1028,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
if self._parent is None and self.session.twophase:
try:
for t in set(self._connections.values()):
- t[1].prepare()
+ cast("TwoPhaseTransaction", t[1]).prepare()
except:
with util.safe_reraise():
self.rollback()
@@ -929,9 +1060,7 @@ class SessionTransaction(_StateChange, TransactionalContext):
self.close()
if _to_root and self._parent:
- return self._parent.commit(_to_root=True)
-
- return self._parent
+ self._parent.commit(_to_root=True)
@_StateChange.declare_states(
(
@@ -941,9 +1070,12 @@ class SessionTransaction(_StateChange, TransactionalContext):
),
SessionTransactionState.CLOSED,
)
- def rollback(self, _capture_exception=False, _to_root=False):
+ def rollback(
+ self, _capture_exception: bool = False, _to_root: bool = False
+ ) -> None:
stx = self.session._transaction
+ assert stx is not None
if stx is not self:
for subtransaction in stx._iterate_self_and_parents(upto=self):
subtransaction.close()
@@ -993,19 +1125,18 @@ class SessionTransaction(_StateChange, TransactionalContext):
if self._parent and _capture_exception:
self._parent._rollback_exception = sys.exc_info()[1]
- if rollback_err:
+ if rollback_err and rollback_err[1]:
raise rollback_err[1].with_traceback(rollback_err[2])
sess.dispatch.after_soft_rollback(sess, self)
if _to_root and self._parent:
- return self._parent.rollback(_to_root=True)
- return self._parent
+ self._parent.rollback(_to_root=True)
@_StateChange.declare_states(
_StateChangeStates.ANY, SessionTransactionState.CLOSED
)
- def close(self, invalidate=False):
+ def close(self, invalidate: bool = False) -> None:
if self.nested:
self.session._nested_transaction = (
self._previous_nested_transaction
@@ -1027,25 +1158,30 @@ class SessionTransaction(_StateChange, TransactionalContext):
self._state = SessionTransactionState.CLOSED
sess = self.session
- self.session = None
- self._connections = None
+ # TODO: these two None sets were historically after the
+ # event hook below, and in 2.0 I changed it this way for some reason,
+ # and I remember there being a reason, but not what it was.
+ # Why do we need to get rid of them at all? test_memusage::CycleTest
+ # passes with these commented out.
+ # self.session = None # type: ignore
+ # self._connections = None # type: ignore
sess.dispatch.after_transaction_end(sess, self)
- def _get_subject(self):
+ def _get_subject(self) -> Session:
return self.session
- def _transaction_is_active(self):
+ def _transaction_is_active(self) -> bool:
return self._state is SessionTransactionState.ACTIVE
- def _transaction_is_closed(self):
+ def _transaction_is_closed(self) -> bool:
return self._state is SessionTransactionState.CLOSED
- def _rollback_can_be_called(self):
+ def _rollback_can_be_called(self) -> bool:
return self._state not in (COMMITTED, CLOSED)
-class Session(_SessionClassMethods):
+class Session(_SessionClassMethods, EventTarget):
"""Manages persistence operations for ORM-mapped objects.
The Session's usage paradigm is described at :doc:`/orm/session`.
@@ -1055,15 +1191,27 @@ class Session(_SessionClassMethods):
_is_asyncio = False
- identity_map: identity.IdentityMap
- _new: Dict["InstanceState", Any]
- _deleted: Dict["InstanceState", Any]
+ dispatch: dispatcher[Session]
+
+ identity_map: IdentityMap
+ """A mapping of object identities to objects themselves.
+
+ Iterating through ``Session.identity_map.values()`` provides
+ access to the full set of persistent objects (i.e., those
+ that have row identity) currently in the session.
+
+ .. seealso::
+
+ :func:`.identity_key` - helper function to produce the keys used
+ in this dictionary.
+
+ """
+
+ _new: Dict[InstanceState[Any], Any]
+ _deleted: Dict[InstanceState[Any], Any]
bind: Optional[Union[Engine, Connection]]
- __binds: Dict[
- Union[type, "Mapper", "Table"],
- Union[engine.Engine, engine.Connection],
- ]
- _flusing: bool
+ __binds: Dict[_SessionBindKey, _SessionBind]
+ _flushing: bool
_warn_on_events: bool
_transaction: Optional[SessionTransaction]
_nested_transaction: Optional[SessionTransaction]
@@ -1072,24 +1220,19 @@ class Session(_SessionClassMethods):
expire_on_commit: bool
enable_baked_queries: bool
twophase: bool
- _query_cls: Type[Query]
+ _query_cls: Type[Query[Any]]
def __init__(
self,
- bind: Optional[Union[engine.Engine, engine.Connection]] = None,
+ bind: Optional[_SessionBind] = None,
autoflush: bool = True,
future: Literal[True] = True,
expire_on_commit: bool = True,
twophase: bool = False,
- binds: Optional[
- Dict[
- Union[type, "Mapper", "Table"],
- Union[engine.Engine, engine.Connection],
- ]
- ] = None,
+ binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None,
enable_baked_queries: bool = True,
info: Optional[Dict[Any, Any]] = None,
- query_cls: Optional[Type[query.Query]] = None,
+ query_cls: Optional[Type[Query[Any]]] = None,
autocommit: Literal[False] = False,
):
r"""Construct a new Session.
@@ -1249,23 +1392,23 @@ class Session(_SessionClassMethods):
_sessions[self.hash_key] = self
# used by sqlalchemy.engine.util.TransactionalContext
- _trans_context_manager = None
+ _trans_context_manager: Optional[TransactionalContext] = None
- connection_callable = None
+ connection_callable: Optional[_ConnectionCallableProto] = None
- def __enter__(self):
+ def __enter__(self) -> Session:
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
self.close()
@contextlib.contextmanager
- def _maker_context_manager(self):
+ def _maker_context_manager(self) -> Iterator[Session]:
with self:
with self.begin():
yield self
- def in_transaction(self):
+ def in_transaction(self) -> bool:
"""Return True if this :class:`_orm.Session` has begun a transaction.
.. versionadded:: 1.4
@@ -1278,7 +1421,7 @@ class Session(_SessionClassMethods):
"""
return self._transaction is not None
- def in_nested_transaction(self):
+ def in_nested_transaction(self) -> bool:
"""Return True if this :class:`_orm.Session` has begun a nested
transaction, e.g. SAVEPOINT.
@@ -1287,7 +1430,7 @@ class Session(_SessionClassMethods):
"""
return self._nested_transaction is not None
- def get_transaction(self):
+ def get_transaction(self) -> Optional[SessionTransaction]:
"""Return the current root transaction in progress, if any.
.. versionadded:: 1.4
@@ -1298,7 +1441,7 @@ class Session(_SessionClassMethods):
trans = trans._parent
return trans
- def get_nested_transaction(self):
+ def get_nested_transaction(self) -> Optional[SessionTransaction]:
"""Return the current nested transaction in progress, if any.
.. versionadded:: 1.4
@@ -1308,7 +1451,7 @@ class Session(_SessionClassMethods):
return self._nested_transaction
@util.memoized_property
- def info(self):
+ def info(self) -> Dict[Any, Any]:
"""A user-modifiable dictionary.
The initial value of this dictionary can be populated using the
@@ -1320,16 +1463,18 @@ class Session(_SessionClassMethods):
"""
return {}
- def _autobegin(self):
+ def _autobegin_t(self) -> SessionTransaction:
if self._transaction is None:
trans = SessionTransaction(self, autobegin=True)
assert self._transaction is trans
- return True
+ return trans
- return False
+ return self._transaction
- def begin(self, nested=False, _subtrans=False):
+ def begin(
+ self, nested: bool = False, _subtrans: bool = False
+ ) -> SessionTransaction:
"""Begin a transaction, or nested transaction,
on this :class:`.Session`, if one is not already begun.
@@ -1364,13 +1509,16 @@ class Session(_SessionClassMethods):
"""
- if self._autobegin():
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
+
if not nested and not _subtrans:
- return self._transaction
+ return trans
- if self._transaction is not None:
+ if trans is not None:
if _subtrans or nested:
- trans = self._transaction._begin(nested=nested)
+ trans = trans._begin(nested=nested)
assert self._transaction is trans
if nested:
self._nested_transaction = trans
@@ -1386,9 +1534,12 @@ class Session(_SessionClassMethods):
trans = SessionTransaction(self)
assert self._transaction is trans
- return self._transaction # needed for __enter__/__exit__ hook
+ if TYPE_CHECKING:
+ assert self._transaction is not None
+
+ return trans # needed for __enter__/__exit__ hook
- def begin_nested(self):
+ def begin_nested(self) -> SessionTransaction:
"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
The target database(s) and associated drivers must support SQL
@@ -1413,7 +1564,7 @@ class Session(_SessionClassMethods):
"""
return self.begin(nested=True)
- def rollback(self):
+ def rollback(self) -> None:
"""Rollback the current transaction in progress.
If no transaction is in progress, this method is a pass-through.
@@ -1450,11 +1601,11 @@ class Session(_SessionClassMethods):
:ref:`unitofwork_transaction`
"""
- if self._transaction is None:
- if not self._autobegin():
- raise sa_exc.InvalidRequestError("No transaction is begun.")
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
- self._transaction.commit(_to_root=True)
+ trans.commit(_to_root=True)
def prepare(self) -> None:
"""Prepare the current transaction in progress for two phase commit.
@@ -1467,16 +1618,16 @@ class Session(_SessionClassMethods):
:exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
"""
- if self._transaction is None:
- if not self._autobegin():
- raise sa_exc.InvalidRequestError("No transaction is begun.")
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
- self._transaction.prepare()
+ trans.prepare()
def connection(
self,
bind_arguments: Optional[Dict[str, Any]] = None,
- execution_options: Optional["_ExecuteOptions"] = None,
+ execution_options: Optional[_ExecuteOptions] = None,
) -> "Connection":
r"""Return a :class:`_engine.Connection` object corresponding to this
:class:`.Session` object's transactional state.
@@ -1521,24 +1672,28 @@ class Session(_SessionClassMethods):
execution_options=execution_options,
)
- def _connection_for_bind(self, engine, execution_options=None, **kw):
+ def _connection_for_bind(
+ self,
+ engine: _SessionBind,
+ execution_options: Optional[_ExecuteOptions] = None,
+ **kw: Any,
+ ) -> Connection:
TransactionalContext._trans_ctx_check(self)
- if self._transaction is None:
- assert self._autobegin()
- return self._transaction._connection_for_bind(
- engine, execution_options
- )
+ trans = self._transaction
+ if trans is None:
+ trans = self._autobegin_t()
+ return trans._connection_for_bind(engine, execution_options)
def execute(
self,
- statement: "Executable",
- params: Optional["_ExecuteParams"] = None,
- execution_options: "_ExecuteOptions" = util.EMPTY_DICT,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[Dict[str, Any]] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ):
+ ) -> Result:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
@@ -1603,6 +1758,8 @@ class Session(_SessionClassMethods):
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
+ if TYPE_CHECKING:
+ assert isinstance(compile_state_cls, ORMCompileState)
else:
compile_state_cls = None
@@ -1645,9 +1802,9 @@ class Session(_SessionClassMethods):
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- result = fn(orm_exec_state)
- if result:
- return result
+ fn_result: Optional[Result] = fn(orm_exec_state)
+ if fn_result:
+ return fn_result
statement = orm_exec_state.statement
execution_options = orm_exec_state.local_execution_options
@@ -1655,7 +1812,9 @@ class Session(_SessionClassMethods):
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind)
- result = conn.execute(statement, params or {}, execution_options)
+ result: Result = conn.execute(
+ statement, params or {}, execution_options
+ )
if compile_state_cls:
result = compile_state_cls.orm_setup_cursor_result(
@@ -1671,12 +1830,12 @@ class Session(_SessionClassMethods):
def scalar(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> Any:
"""Execute a statement and return a scalar result.
Usage and parameters are the same as that of
@@ -1695,12 +1854,12 @@ class Session(_SessionClassMethods):
def scalars(
self,
- statement,
- params=None,
- execution_options=util.EMPTY_DICT,
- bind_arguments=None,
- **kw,
- ):
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[Dict[str, Any]] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
"""Execute a statement and return the results as scalars.
Usage and parameters are the same as that of
@@ -1722,7 +1881,7 @@ class Session(_SessionClassMethods):
**kw,
).scalars()
- def close(self):
+ def close(self) -> None:
"""Close out the transactional resources and ORM objects used by this
:class:`_orm.Session`.
@@ -1754,7 +1913,7 @@ class Session(_SessionClassMethods):
"""
self._close_impl(invalidate=False)
- def invalidate(self):
+ def invalidate(self) -> None:
"""Close this Session, using connection invalidation.
This is a variant of :meth:`.Session.close` that will additionally
@@ -1790,13 +1949,13 @@ class Session(_SessionClassMethods):
"""
self._close_impl(invalidate=True)
- def _close_impl(self, invalidate):
+ def _close_impl(self, invalidate: bool) -> None:
self.expunge_all()
if self._transaction is not None:
for transaction in self._transaction._iterate_self_and_parents():
transaction.close(invalidate)
- def expunge_all(self):
+ def expunge_all(self) -> None:
"""Remove all object instances from this ``Session``.
This is equivalent to calling ``expunge(obj)`` on all objects in this
@@ -1812,7 +1971,7 @@ class Session(_SessionClassMethods):
statelib.InstanceState._detach_states(all_states, self)
- def _add_bind(self, key, bind):
+ def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None:
try:
insp = inspect(key)
except sa_exc.NoInspectionAvailable as err:
@@ -1834,7 +1993,9 @@ class Session(_SessionClassMethods):
"Not an acceptable bind target: %s" % key
)
- def bind_mapper(self, mapper, bind):
+ def bind_mapper(
+ self, mapper: _EntityBindKey[_O], bind: _SessionBind
+ ) -> None:
"""Associate a :class:`_orm.Mapper` or arbitrary Python class with a
"bind", e.g. an :class:`_engine.Engine` or
:class:`_engine.Connection`.
@@ -1862,7 +2023,7 @@ class Session(_SessionClassMethods):
"""
self._add_bind(mapper, bind)
- def bind_table(self, table, bind):
+ def bind_table(self, table: Table, bind: _SessionBind) -> None:
"""Associate a :class:`_schema.Table` with a "bind", e.g. an
:class:`_engine.Engine`
or :class:`_engine.Connection`.
@@ -1892,12 +2053,12 @@ class Session(_SessionClassMethods):
def get_bind(
self,
- mapper=None,
- clause=None,
- bind=None,
- _sa_skip_events=None,
- _sa_skip_for_implicit_returning=False,
- ):
+ mapper: Optional[_EntityBindKey[_O]] = None,
+ clause: Optional[ClauseElement] = None,
+ bind: Optional[_SessionBind] = None,
+ _sa_skip_events: Optional[bool] = None,
+ _sa_skip_for_implicit_returning: bool = False,
+ ) -> Union[Engine, Connection]:
"""Return a "bind" to which this :class:`.Session` is bound.
The "bind" is usually an instance of :class:`_engine.Engine`,
@@ -1995,23 +2156,25 @@ class Session(_SessionClassMethods):
# look more closely at the mapper.
if mapper is not None:
try:
- mapper = inspect(mapper)
+ inspected_mapper = inspect(mapper)
except sa_exc.NoInspectionAvailable as err:
if isinstance(mapper, type):
raise exc.UnmappedClassError(mapper) from err
else:
raise
+ else:
+ inspected_mapper = None
# match up the mapper or clause in the __binds
if self.__binds:
# matching mappers and selectables to entries in the
# binds dictionary; supported use case.
- if mapper:
- for cls in mapper.class_.__mro__:
+ if inspected_mapper:
+ for cls in inspected_mapper.class_.__mro__:
if cls in self.__binds:
return self.__binds[cls]
if clause is None:
- clause = mapper.persist_selectable
+ clause = inspected_mapper.persist_selectable
if clause is not None:
plugin_subject = clause._propagate_attrs.get(
@@ -2025,6 +2188,8 @@ class Session(_SessionClassMethods):
for obj in visitors.iterate(clause):
if obj in self.__binds:
+ if TYPE_CHECKING:
+ assert isinstance(obj, Table)
return self.__binds[obj]
# none of the __binds matched, but we have a fallback bind.
@@ -2033,17 +2198,19 @@ class Session(_SessionClassMethods):
return self.bind
context = []
- if mapper is not None:
- context.append("mapper %s" % mapper)
+ if inspected_mapper is not None:
+ context.append(f"mapper {inspected_mapper}")
if clause is not None:
context.append("SQL expression")
raise sa_exc.UnboundExecutionError(
- "Could not locate a bind configured on %s or this Session."
- % (", ".join(context),),
+ f"Could not locate a bind configured on "
+ f'{", ".join(context)} or this Session.'
)
- def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
+ def query(
+ self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ ) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -2065,12 +2232,12 @@ class Session(_SessionClassMethods):
def _identity_lookup(
self,
- mapper,
- primary_key_identity,
- identity_token=None,
- passive=attributes.PASSIVE_OFF,
- lazy_loaded_from=None,
- ):
+ mapper: Mapper[_O],
+ primary_key_identity: Union[Any, Tuple[Any, ...]],
+ identity_token: Any = None,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
+ lazy_loaded_from: Optional[InstanceState[Any]] = None,
+ ) -> Union[Optional[_O], LoaderCallableStatus]:
"""Locate an object in the identity map.
Given a primary key identity, constructs an identity key and then
@@ -2117,9 +2284,9 @@ class Session(_SessionClassMethods):
)
return loading.get_from_identity(self, mapper, key, passive)
- @property
+ @util.non_memoized_property
@contextlib.contextmanager
- def no_autoflush(self):
+ def no_autoflush(self) -> Iterator[Session]:
"""Return a context manager that disables autoflush.
e.g.::
@@ -2145,7 +2312,7 @@ class Session(_SessionClassMethods):
finally:
self.autoflush = autoflush
- def _autoflush(self):
+ def _autoflush(self) -> None:
if self.autoflush and not self._flushing:
try:
self.flush()
@@ -2161,7 +2328,12 @@ class Session(_SessionClassMethods):
)
raise e.with_traceback(sys.exc_info()[2])
- def refresh(self, instance, attribute_names=None, with_for_update=None):
+ def refresh(
+ self,
+ instance: object,
+ attribute_names: Optional[Iterable[str]] = None,
+ with_for_update: Optional[ForUpdateArg] = None,
+ ) -> None:
"""Expire and refresh attributes on the given instance.
The selected attributes will first be expired as they would when using
@@ -2233,7 +2405,7 @@ class Session(_SessionClassMethods):
"A blank dictionary is ambiguous."
)
- with_for_update = query.ForUpdateArg._from_argument(with_for_update)
+ with_for_update = ForUpdateArg._from_argument(with_for_update)
stmt = sql.select(object_mapper(instance))
if (
@@ -2251,7 +2423,7 @@ class Session(_SessionClassMethods):
"Could not refresh instance '%s'" % instance_str(instance)
)
- def expire_all(self):
+ def expire_all(self) -> None:
"""Expires all persistent instances within this Session.
When any attributes on a persistent instance is next accessed,
@@ -2286,7 +2458,9 @@ class Session(_SessionClassMethods):
for state in self.identity_map.all_states():
state._expire(state.dict, self.identity_map._modified)
- def expire(self, instance, attribute_names=None):
+ def expire(
+ self, instance: object, attribute_names: Optional[Iterable[str]] = None
+ ) -> None:
"""Expire the attributes on an instance.
Marks the attributes of an instance as out of date. When an expired
@@ -2329,7 +2503,11 @@ class Session(_SessionClassMethods):
raise exc.UnmappedInstanceError(instance) from err
self._expire_state(state, attribute_names)
- def _expire_state(self, state, attribute_names):
+ def _expire_state(
+ self,
+ state: InstanceState[Any],
+ attribute_names: Optional[Iterable[str]],
+ ) -> None:
self._validate_persistent(state)
if attribute_names:
state._expire_attributes(state.dict, attribute_names)
@@ -2343,7 +2521,9 @@ class Session(_SessionClassMethods):
for o, m, st_, dct_ in cascaded:
self._conditional_expire(st_)
- def _conditional_expire(self, state, autoflush=None):
+ def _conditional_expire(
+ self, state: InstanceState[Any], autoflush: Optional[bool] = None
+ ) -> None:
"""Expire a state if persistent, else expunge if pending"""
if state.key:
@@ -2352,7 +2532,7 @@ class Session(_SessionClassMethods):
self._new.pop(state)
state._detach(self)
- def expunge(self, instance):
+ def expunge(self, instance: object) -> None:
"""Remove the `instance` from this ``Session``.
This will free all internal references to the instance. Cascading
@@ -2373,7 +2553,9 @@ class Session(_SessionClassMethods):
)
self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
- def _expunge_states(self, states, to_transient=False):
+ def _expunge_states(
+ self, states: Iterable[InstanceState[Any]], to_transient: bool = False
+ ) -> None:
for state in states:
if state in self._new:
self._new.pop(state)
@@ -2388,7 +2570,7 @@ class Session(_SessionClassMethods):
states, self, to_transient=to_transient
)
- def _register_persistent(self, states):
+ def _register_persistent(self, states: Set[InstanceState[Any]]) -> None:
"""Register all persistent objects from a flush.
This is used both for pending objects moving to the persistent
@@ -2429,11 +2611,13 @@ class Session(_SessionClassMethods):
# state has already replaced this one in the identity
# map (see test/orm/test_naturalpks.py ReversePKsTest)
self.identity_map.safe_discard(state)
- if state in self._transaction._key_switches:
- orig_key = self._transaction._key_switches[state][0]
+ trans = self._transaction
+ assert trans is not None
+ if state in trans._key_switches:
+ orig_key = trans._key_switches[state][0]
else:
orig_key = state.key
- self._transaction._key_switches[state] = (
+ trans._key_switches[state] = (
orig_key,
instance_key,
)
@@ -2470,7 +2654,7 @@ class Session(_SessionClassMethods):
for state in set(states).intersection(self._new):
self._new.pop(state)
- def _register_altered(self, states):
+ def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None:
if self._transaction:
for state in states:
if state in self._new:
@@ -2478,7 +2662,9 @@ class Session(_SessionClassMethods):
else:
self._transaction._dirty[state] = True
- def _remove_newly_deleted(self, states):
+ def _remove_newly_deleted(
+ self, states: Iterable[InstanceState[Any]]
+ ) -> None:
persistent_to_deleted = self.dispatch.persistent_to_deleted or None
for state in states:
if self._transaction:
@@ -2498,7 +2684,7 @@ class Session(_SessionClassMethods):
if persistent_to_deleted is not None:
persistent_to_deleted(self, state)
- def add(self, instance: Any, _warn: bool = True) -> None:
+ def add(self, instance: object, _warn: bool = True) -> None:
"""Place an object in the ``Session``.
Its state will be persisted to the database on the next flush
@@ -2518,7 +2704,7 @@ class Session(_SessionClassMethods):
self._save_or_update_state(state)
- def add_all(self, instances):
+ def add_all(self, instances: Iterable[object]) -> None:
"""Add the given collection of instances to this ``Session``."""
if self._warn_on_events:
@@ -2527,7 +2713,7 @@ class Session(_SessionClassMethods):
for instance in instances:
self.add(instance, _warn=False)
- def _save_or_update_state(self, state):
+ def _save_or_update_state(self, state: InstanceState[Any]) -> None:
state._orphaned_outside_of_session = False
self._save_or_update_impl(state)
@@ -2537,7 +2723,7 @@ class Session(_SessionClassMethods):
):
self._save_or_update_impl(st_)
- def delete(self, instance):
+ def delete(self, instance: object) -> None:
"""Mark an instance as deleted.
The database delete operation occurs upon ``flush()``.
@@ -2553,7 +2739,9 @@ class Session(_SessionClassMethods):
self._delete_impl(state, instance, head=True)
- def _delete_impl(self, state, obj, head):
+ def _delete_impl(
+ self, state: InstanceState[Any], obj: object, head: bool
+ ) -> None:
if state.key is None:
if head:
@@ -2580,23 +2768,28 @@ class Session(_SessionClassMethods):
cascade_states = list(
state.manager.mapper.cascade_iterator("delete", state)
)
+ else:
+ cascade_states = None
self._deleted[state] = obj
if head:
+ if TYPE_CHECKING:
+ assert cascade_states is not None
for o, m, st_, dct_ in cascade_states:
self._delete_impl(st_, o, False)
def get(
self,
- entity,
- ident,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- execution_options=None,
- ):
+ entity: _EntityBindKey[_O],
+ ident: _PKIdentityArgument,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ ) -> Optional[_O]:
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
@@ -2696,7 +2889,7 @@ class Session(_SessionClassMethods):
entity,
ident,
loading.load_on_pk_identity,
- options,
+ options=options,
populate_existing=populate_existing,
with_for_update=with_for_update,
identity_token=identity_token,
@@ -2705,23 +2898,24 @@ class Session(_SessionClassMethods):
def _get_impl(
self,
- entity,
- primary_key_identity,
- db_load_fn,
- options=None,
- populate_existing=False,
- with_for_update=None,
- identity_token=None,
- execution_options=None,
- ):
+ entity: _EntityBindKey[_O],
+ primary_key_identity: _PKIdentityArgument,
+ db_load_fn: Callable[..., _O],
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ populate_existing: bool = False,
+ with_for_update: Optional[ForUpdateArg] = None,
+ identity_token: Optional[Any] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+ ) -> Optional[_O]:
# convert composite types to individual args
- if hasattr(primary_key_identity, "__composite_values__"):
+ if is_composite_class(primary_key_identity):
primary_key_identity = primary_key_identity.__composite_values__()
- mapper = inspect(entity)
+ mapper: Optional[Mapper[_O]] = inspect(entity)
- if not mapper or not mapper.is_mapper:
+ if mapper is None or not mapper.is_mapper:
raise sa_exc.ArgumentError(
"Expected mapped class or mapper, got: %r" % entity
)
@@ -2729,7 +2923,7 @@ class Session(_SessionClassMethods):
is_dict = isinstance(primary_key_identity, dict)
if not is_dict:
primary_key_identity = util.to_list(
- primary_key_identity, default=(None,)
+ primary_key_identity, default=[None]
)
if len(primary_key_identity) != len(mapper.primary_key):
@@ -2770,11 +2964,12 @@ class Session(_SessionClassMethods):
if instance is not None:
# reject calls for id in identity map but class
# mismatch.
- if not issubclass(instance.__class__, mapper.class_):
+ if not isinstance(instance, mapper.class_):
return None
return instance
- elif instance is attributes.PASSIVE_CLASS_MISMATCH:
- return None
+
+ # TODO: this was being tested before, but this is not possible
+ assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH
# set_label_style() not strictly necessary, however this will ensure
# that tablename_colname style is used which at the moment is
@@ -2788,7 +2983,7 @@ class Session(_SessionClassMethods):
LABEL_STYLE_TABLENAME_PLUS_COL
)
if with_for_update is not None:
- statement._for_update_arg = query.ForUpdateArg._from_argument(
+ statement._for_update_arg = ForUpdateArg._from_argument(
with_for_update
)
@@ -2803,7 +2998,13 @@ class Session(_SessionClassMethods):
load_options=load_options,
)
- def merge(self, instance, load=True, options=None):
+ def merge(
+ self,
+ instance: _O,
+ *,
+ load: bool = True,
+ options: Optional[Sequence[ORMOption]] = None,
+ ) -> _O:
"""Copy the state of a given instance into a corresponding instance
within this :class:`.Session`.
@@ -2866,8 +3067,8 @@ class Session(_SessionClassMethods):
if self._warn_on_events:
self._flush_warning("Session.merge()")
- _recursive = {}
- _resolve_conflict_map = {}
+ _recursive: Dict[InstanceState[Any], object] = {}
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {}
if load:
# flush current contents if we expect to load data
@@ -2890,20 +3091,23 @@ class Session(_SessionClassMethods):
def _merge(
self,
- state,
- state_dict,
- load=True,
- options=None,
- _recursive=None,
- _resolve_conflict_map=None,
- ):
+ state: InstanceState[_O],
+ state_dict: _InstanceDict,
+ *,
+ options: Optional[Sequence[ORMOption]] = None,
+ load: bool,
+ _recursive: Dict[InstanceState[Any], object],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> _O:
mapper = _state_mapper(state)
if state in _recursive:
- return _recursive[state]
+ return cast(_O, _recursive[state])
new_instance = False
key = state.key
+ merged: Optional[_O]
+
if key is None:
if state in self._new:
util.warn(
@@ -2920,7 +3124,9 @@ class Session(_SessionClassMethods):
"load=False."
)
key = mapper._identity_key_from_state(state)
- key_is_persistent = attributes.NEVER_SET not in key[1] and (
+ key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[
+ 1
+ ] and (
not _none_set.intersection(key[1])
or (
mapper.allow_partial_pks
@@ -2941,7 +3147,7 @@ class Session(_SessionClassMethods):
if merged is None:
if key_is_persistent and key in _resolve_conflict_map:
- merged = _resolve_conflict_map[key]
+ merged = cast(_O, _resolve_conflict_map[key])
elif not load:
if state.modified:
@@ -2986,19 +3192,21 @@ class Session(_SessionClassMethods):
state,
state_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
merged_version = mapper._get_state_attr_by_column(
merged_state,
merged_dict,
mapper.version_id_col,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
if (
- existing_version is not attributes.PASSIVE_NO_RESULT
- and merged_version is not attributes.PASSIVE_NO_RESULT
+ existing_version
+ is not LoaderCallableStatus.PASSIVE_NO_RESULT
+ and merged_version
+ is not LoaderCallableStatus.PASSIVE_NO_RESULT
and existing_version != merged_version
):
raise exc.StaleDataError(
@@ -3043,14 +3251,14 @@ class Session(_SessionClassMethods):
merged_state.manager.dispatch.load(merged_state, None)
return merged
- def _validate_persistent(self, state):
+ def _validate_persistent(self, state: InstanceState[Any]) -> None:
if not self.identity_map.contains_state(state):
raise sa_exc.InvalidRequestError(
"Instance '%s' is not persistent within this Session"
% state_str(state)
)
- def _save_impl(self, state):
+ def _save_impl(self, state: InstanceState[Any]) -> None:
if state.key is not None:
raise sa_exc.InvalidRequestError(
"Object '%s' already has an identity - "
@@ -3065,7 +3273,9 @@ class Session(_SessionClassMethods):
if to_attach:
self._after_attach(state, obj)
- def _update_impl(self, state, revert_deletion=False):
+ def _update_impl(
+ self, state: InstanceState[Any], revert_deletion: bool = False
+ ) -> None:
if state.key is None:
raise sa_exc.InvalidRequestError(
"Instance '%s' is not persisted" % state_str(state)
@@ -3103,13 +3313,13 @@ class Session(_SessionClassMethods):
elif revert_deletion:
self.dispatch.deleted_to_persistent(self, state)
- def _save_or_update_impl(self, state):
+ def _save_or_update_impl(self, state: InstanceState[Any]) -> None:
if state.key is None:
self._save_impl(state)
else:
self._update_impl(state)
- def enable_relationship_loading(self, obj):
+ def enable_relationship_loading(self, obj: object) -> None:
"""Associate an object with this :class:`.Session` for related
object loading.
@@ -3174,8 +3384,8 @@ class Session(_SessionClassMethods):
if to_attach:
self._after_attach(state, obj)
- def _before_attach(self, state, obj):
- self._autobegin()
+ def _before_attach(self, state: InstanceState[Any], obj: object) -> bool:
+ self._autobegin_t()
if state.session_id == self.hash_key:
return False
@@ -3191,7 +3401,7 @@ class Session(_SessionClassMethods):
return True
- def _after_attach(self, state, obj):
+ def _after_attach(self, state: InstanceState[Any], obj: object) -> None:
state.session_id = self.hash_key
if state.modified and state._strong_obj is None:
state._strong_obj = obj
@@ -3202,7 +3412,7 @@ class Session(_SessionClassMethods):
else:
self.dispatch.transient_to_pending(self, state)
- def __contains__(self, instance):
+ def __contains__(self, instance: object) -> bool:
"""Return True if the instance is associated with this session.
The instance may be pending or persistent within the Session for a
@@ -3215,7 +3425,7 @@ class Session(_SessionClassMethods):
raise exc.UnmappedInstanceError(instance) from err
return self._contains_state(state)
- def __iter__(self):
+ def __iter__(self) -> Iterator[object]:
"""Iterate over all pending or persistent instances within this
Session.
@@ -3224,10 +3434,10 @@ class Session(_SessionClassMethods):
list(self._new.values()) + list(self.identity_map.values())
)
- def _contains_state(self, state):
+ def _contains_state(self, state: InstanceState[Any]) -> bool:
return state in self._new or self.identity_map.contains_state(state)
- def flush(self, objects=None):
+ def flush(self, objects: Optional[Sequence[Any]] = None) -> None:
"""Flush all the object changes to the database.
Writes out all pending object creations, deletions and modifications
@@ -3261,7 +3471,7 @@ class Session(_SessionClassMethods):
finally:
self._flushing = False
- def _flush_warning(self, method):
+ def _flush_warning(self, method: Any) -> None:
util.warn(
"Usage of the '%s' operation is not currently supported "
"within the execution stage of the flush process. "
@@ -3269,14 +3479,14 @@ class Session(_SessionClassMethods):
"event listeners or connection-level operations instead." % method
)
- def _is_clean(self):
+ def _is_clean(self) -> bool:
return (
not self.identity_map.check_modified()
and not self._deleted
and not self._new
)
- def _flush(self, objects=None):
+ def _flush(self, objects: Optional[Sequence[object]] = None) -> None:
dirty = self._dirty_states
if not dirty and not self._deleted and not self._new:
@@ -3398,11 +3608,11 @@ class Session(_SessionClassMethods):
def bulk_save_objects(
self,
- objects,
- return_defaults=False,
- update_changed_only=True,
- preserve_order=True,
- ):
+ objects: Iterable[object],
+ return_defaults: bool = False,
+ update_changed_only: bool = True,
+ preserve_order: bool = True,
+ ) -> None:
"""Perform a bulk save of the given list of objects.
The bulk save feature allows mapped objects to be used as the
@@ -3496,6 +3706,8 @@ class Session(_SessionClassMethods):
"""
+ obj_states: Iterable[InstanceState[Any]]
+
obj_states = (attributes.instance_state(obj) for obj in objects)
if not preserve_order:
@@ -3508,7 +3720,9 @@ class Session(_SessionClassMethods):
key=lambda state: (id(state.mapper), state.key is not None),
)
- def grouping_key(state):
+ def grouping_key(
+ state: InstanceState[_O],
+ ) -> Tuple[Mapper[_O], bool]:
return (state.mapper, state.key is not None)
for (mapper, isupdate), states in itertools.groupby(
@@ -3525,8 +3739,12 @@ class Session(_SessionClassMethods):
)
def bulk_insert_mappings(
- self, mapper, mappings, return_defaults=False, render_nulls=False
- ):
+ self,
+ mapper: Mapper[Any],
+ mappings: Iterable[Dict[str, Any]],
+ return_defaults: bool = False,
+ render_nulls: bool = False,
+ ) -> None:
"""Perform a bulk insert of the given list of mapping dictionaries.
The bulk insert feature allows plain Python dictionaries to be used as
@@ -3633,7 +3851,9 @@ class Session(_SessionClassMethods):
render_nulls,
)
- def bulk_update_mappings(self, mapper, mappings):
+ def bulk_update_mappings(
+ self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
+ ) -> None:
"""Perform a bulk update of the given list of mapping dictionaries.
The bulk update feature allows plain Python dictionaries to be used as
@@ -3696,14 +3916,14 @@ class Session(_SessionClassMethods):
def _bulk_save_mappings(
self,
- mapper,
- mappings,
- isupdate,
- isstates,
- return_defaults,
- update_changed_only,
- render_nulls,
- ):
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ isupdate: bool,
+ isstates: bool,
+ return_defaults: bool,
+ update_changed_only: bool,
+ render_nulls: bool,
+ ) -> None:
mapper = _class_to_mapper(mapper)
self._flushing = True
@@ -3734,7 +3954,9 @@ class Session(_SessionClassMethods):
finally:
self._flushing = False
- def is_modified(self, instance, include_collections=True):
+ def is_modified(
+ self, instance: object, include_collections: bool = True
+ ) -> bool:
r"""Return ``True`` if the given instance has locally
modified attributes.
@@ -3800,7 +4022,7 @@ class Session(_SessionClassMethods):
continue
(added, unchanged, deleted) = attr.impl.get_history(
- state, dict_, passive=attributes.NO_CHANGE
+ state, dict_, passive=PassiveFlag.NO_CHANGE
)
if added or deleted:
@@ -3809,7 +4031,7 @@ class Session(_SessionClassMethods):
return False
@property
- def is_active(self):
+ def is_active(self) -> bool:
"""True if this :class:`.Session` not in "partial rollback" state.
.. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
@@ -3838,22 +4060,8 @@ class Session(_SessionClassMethods):
"""
return self._transaction is None or self._transaction.is_active
- identity_map = None
- """A mapping of object identities to objects themselves.
-
- Iterating through ``Session.identity_map.values()`` provides
- access to the full set of persistent objects (i.e., those
- that have row identity) currently in the session.
-
- .. seealso::
-
- :func:`.identity_key` - helper function to produce the keys used
- in this dictionary.
-
- """
-
@property
- def _dirty_states(self):
+ def _dirty_states(self) -> Iterable[InstanceState[Any]]:
"""The set of all persistent states considered dirty.
This method returns all states that were modified including
@@ -3863,7 +4071,7 @@ class Session(_SessionClassMethods):
return self.identity_map._dirty_states()
@property
- def dirty(self):
+ def dirty(self) -> IdentitySet:
"""The set of all persistent instances considered dirty.
E.g.::
@@ -3886,7 +4094,7 @@ class Session(_SessionClassMethods):
attributes, use the :meth:`.Session.is_modified` method.
"""
- return util.IdentitySet(
+ return IdentitySet(
[
state.obj()
for state in self._dirty_states
@@ -3895,13 +4103,13 @@ class Session(_SessionClassMethods):
)
@property
- def deleted(self):
+ def deleted(self) -> IdentitySet:
"The set of all instances marked as 'deleted' within this ``Session``"
return util.IdentitySet(list(self._deleted.values()))
@property
- def new(self):
+ def new(self) -> IdentitySet:
"The set of all instances marked as 'new' within this ``Session``."
return util.IdentitySet(list(self._new.values()))
@@ -4002,14 +4210,16 @@ class sessionmaker(_SessionClassMethods):
"""
+ class_: Type[Session]
+
def __init__(
self,
- bind=None,
- class_=Session,
- autoflush=True,
- expire_on_commit=True,
- info=None,
- **kw,
+ bind: Optional[_SessionBind] = None,
+ class_: Type[Session] = Session,
+ autoflush: bool = True,
+ expire_on_commit: bool = True,
+ info: Optional[Dict[Any, Any]] = None,
+ **kw: Any,
):
r"""Construct a new :class:`.sessionmaker`.
@@ -4052,7 +4262,7 @@ class sessionmaker(_SessionClassMethods):
# events can be associated with it specifically.
self.class_ = type(class_.__name__, (class_,), {})
- def begin(self):
+ def begin(self) -> contextlib.AbstractContextManager[Session]:
"""Produce a context manager that both provides a new
:class:`_orm.Session` as well as a transaction that commits.
@@ -4074,7 +4284,7 @@ class sessionmaker(_SessionClassMethods):
session = self()
return session._maker_context_manager()
- def __call__(self, **local_kw):
+ def __call__(self, **local_kw: Any) -> Session:
"""Produce a new :class:`.Session` object using the configuration
established in this :class:`.sessionmaker`.
@@ -4094,7 +4304,7 @@ class sessionmaker(_SessionClassMethods):
local_kw.setdefault(k, v)
return self.class_(**local_kw)
- def configure(self, **new_kw):
+ def configure(self, **new_kw: Any) -> None:
"""(Re)configure the arguments for this sessionmaker.
e.g.::
@@ -4105,7 +4315,7 @@ class sessionmaker(_SessionClassMethods):
"""
self.kw.update(new_kw)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(class_=%r, %s)" % (
self.__class__.__name__,
self.class_.__name__,
@@ -4113,7 +4323,7 @@ class sessionmaker(_SessionClassMethods):
)
-def close_all_sessions():
+def close_all_sessions() -> None:
"""Close all sessions in memory.
This function consults a global registry of all :class:`.Session` objects
@@ -4131,7 +4341,7 @@ def close_all_sessions():
sess.close()
-def make_transient(instance):
+def make_transient(instance: object) -> None:
"""Alter the state of the given instance so that it is :term:`transient`.
.. note::
@@ -4195,7 +4405,7 @@ def make_transient(instance):
del state._deleted
-def make_transient_to_detached(instance):
+def make_transient_to_detached(instance: object) -> None:
"""Make the given transient instance :term:`detached`.
.. note::
@@ -4234,7 +4444,7 @@ def make_transient_to_detached(instance):
state._expire_attributes(state.dict, state.unloaded_expirable)
-def object_session(instance):
+def object_session(instance: object) -> Optional[Session]:
"""Return the :class:`.Session` to which the given instance belongs.
This is essentially the same as the :attr:`.InstanceState.session`