summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-15 11:05:36 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-20 15:14:09 -0400
commitaeeff72e806420bf85e2e6723b1f941df38a3e1a (patch)
tree0bed521b4d7c4860f998e51ba5e318d18b2f5900 /lib/sqlalchemy/orm
parent13a8552053c21a9fa7ff6f992ed49ee92cca73e4 (diff)
downloadsqlalchemy-aeeff72e806420bf85e2e6723b1f941df38a3e1a.tar.gz
pep-484: ORM public API, constructors
for the moment, abandoning using @overload with relationship() and mapped_column(). The overloads are very difficult to get working at all, and the overloads that were there all wouldn't pass on mypy. various techniques of getting them to "work", meaning having right hand side dictate what's legal on the left, have mixed success and wont give consistent results; additionally, it's legal to have Optional / non-optional independent of nullable in any case for columns. relationship cases are less ambiguous but mypy was not going along with things. we have a comprehensive system of allowing left side annotations to drive the right side, in the absense of explicit settings on the right. so type-centric SQLAlchemy will be left-side driven just like dataclasses, and the various flags and switches on the right side will just not be needed very much. in other matters, one surprise, forgot to remove string support from orm.join(A, B, "somename") or do deprecations for it in 1.4. This is a really not-directly-used structure barely mentioned in the docs for many years, the example shows a relationship being used, not a string, so we will just change it to raise the usual error here. Change-Id: Iefbbb8d34548b538023890ab8b7c9a5d9496ec6e
Diffstat (limited to 'lib/sqlalchemy/orm')
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py558
-rw-r--r--lib/sqlalchemy/orm/_typing.py51
-rw-r--r--lib/sqlalchemy/orm/attributes.py8
-rw-r--r--lib/sqlalchemy/orm/base.py77
-rw-r--r--lib/sqlalchemy/orm/context.py44
-rw-r--r--lib/sqlalchemy/orm/decl_api.py8
-rw-r--r--lib/sqlalchemy/orm/decl_base.py19
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py6
-rw-r--r--lib/sqlalchemy/orm/events.py13
-rw-r--r--lib/sqlalchemy/orm/exc.py5
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py81
-rw-r--r--lib/sqlalchemy/orm/interfaces.py302
-rw-r--r--lib/sqlalchemy/orm/loading.py7
-rw-r--r--lib/sqlalchemy/orm/mapper.py649
-rw-r--r--lib/sqlalchemy/orm/path_registry.py445
-rw-r--r--lib/sqlalchemy/orm/properties.py67
-rw-r--r--lib/sqlalchemy/orm/query.py7
-rw-r--r--lib/sqlalchemy/orm/relationships.py114
-rw-r--r--lib/sqlalchemy/orm/session.py21
-rw-r--r--lib/sqlalchemy/orm/strategies.py5
-rw-r--r--lib/sqlalchemy/orm/util.py539
21 files changed, 1831 insertions, 1195 deletions
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 7690c05de..457ad5c5a 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -9,20 +9,19 @@ from __future__ import annotations
import typing
from typing import Any
+from typing import Callable
from typing import Collection
-from typing import Dict
-from typing import List
from typing import Optional
from typing import overload
-from typing import Set
from typing import Type
+from typing import TYPE_CHECKING
from typing import Union
-from . import mapper as mapperlib
+from . import mapperlib as mapperlib
+from ._typing import _O
from .base import Mapped
from .descriptor_props import Composite
from .descriptor_props import Synonym
-from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
from .query import AliasOption
@@ -37,11 +36,29 @@ from .. import sql
from .. import util
from ..exc import InvalidRequestError
from ..sql.base import SchemaEventTarget
-from ..sql.selectable import Alias
+from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
-from ..sql.type_api import TypeEngine
from ..util.typing import Literal
+if TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _ORMColumnExprArgument
+ from .descriptor_props import _CompositeAttrType
+ from .interfaces import PropComparator
+ from .query import Query
+ from .relationships import _LazyLoadArgumentType
+ from .relationships import _ORMBackrefArgument
+ from .relationships import _ORMColCollectionArgument
+ from .relationships import _ORMOrderByArgument
+ from .relationships import _RelationshipJoinConditionArgument
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _InfoType
+ from ..sql._typing import _TypeEngineArgument
+ from ..sql.schema import _ServerDefaultType
+ from ..sql.schema import FetchedValue
+ from ..sql.selectable import Alias
+ from ..sql.selectable import Subquery
+
_T = typing.TypeVar("_T")
@@ -61,7 +78,7 @@ SynonymProperty = Synonym
"for entities to be matched up to a query that is established "
"via :meth:`.Query.from_statement` and now does nothing.",
)
-def contains_alias(alias) -> AliasOption:
+def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption:
r"""Return a :class:`.MapperOption` that will indicate to the
:class:`_query.Query`
that the main table has been aliased.
@@ -70,134 +87,36 @@ def contains_alias(alias) -> AliasOption:
return AliasOption(alias)
-# see test/ext/mypy/plain_files/mapped_column.py for mapped column
-# typing tests
-
-
-@overload
-def mapped_column(
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[None] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Any]":
- ...
-
-
-@overload
-def mapped_column(
- __name: str,
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[None] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Any]":
- ...
-
-
-@overload
-def mapped_column(
- __name: str,
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[True] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Optional[_T]]":
- ...
-
-
-@overload
-def mapped_column(
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[True] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Optional[_T]]":
- ...
-
-
-@overload
-def mapped_column(
- __name: str,
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[False] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[_T]":
- ...
-
-
-@overload
-def mapped_column(
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: Literal[False] = ...,
- primary_key: Literal[None] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[_T]":
- ...
-
-
-@overload
-def mapped_column(
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: bool = ...,
- primary_key: Literal[True] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[_T]":
- ...
-
-
-@overload
-def mapped_column(
- __name: str,
- __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]],
- *args: SchemaEventTarget,
- nullable: bool = ...,
- primary_key: Literal[True] = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[_T]":
- ...
-
-
-@overload
-def mapped_column(
- __name: str,
- *args: SchemaEventTarget,
- nullable: bool = ...,
- primary_key: bool = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Any]":
- ...
-
-
-@overload
def mapped_column(
+ __name_pos: Optional[
+ Union[str, _TypeEngineArgument[Any], SchemaEventTarget]
+ ] = None,
+ __type_pos: Optional[
+ Union[_TypeEngineArgument[Any], SchemaEventTarget]
+ ] = None,
*args: SchemaEventTarget,
- nullable: bool = ...,
- primary_key: bool = ...,
- deferred: bool = ...,
- **kw: Any,
-) -> "MappedColumn[Any]":
- ...
-
-
-def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]":
+ nullable: Optional[
+ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
+ ] = SchemaConst.NULL_UNSPECIFIED,
+ primary_key: Optional[bool] = False,
+ deferred: bool = False,
+ name: Optional[str] = None,
+ type_: Optional[_TypeEngineArgument[Any]] = None,
+ autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
+ default: Optional[Any] = None,
+ doc: Optional[str] = None,
+ key: Optional[str] = None,
+ index: Optional[bool] = None,
+ unique: Optional[bool] = None,
+ info: Optional[_InfoType] = None,
+ onupdate: Optional[Any] = None,
+ server_default: Optional[_ServerDefaultType] = None,
+ server_onupdate: Optional[FetchedValue] = None,
+ quote: Optional[bool] = None,
+ system: bool = False,
+ comment: Optional[str] = None,
+ **dialect_kwargs: Any,
+) -> MappedColumn[Any]:
r"""construct a new ORM-mapped :class:`_schema.Column` construct.
The :func:`_orm.mapped_column` function provides an ORM-aware and
@@ -363,12 +282,45 @@ def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]":
"""
- return MappedColumn(*args, **kw)
+ return MappedColumn(
+ __name_pos,
+ __type_pos,
+ *args,
+ name=name,
+ type_=type_,
+ autoincrement=autoincrement,
+ default=default,
+ doc=doc,
+ key=key,
+ index=index,
+ unique=unique,
+ info=info,
+ nullable=nullable,
+ onupdate=onupdate,
+ primary_key=primary_key,
+ server_default=server_default,
+ server_onupdate=server_onupdate,
+ quote=quote,
+ comment=comment,
+ system=system,
+ deferred=deferred,
+ **dialect_kwargs,
+ )
def column_property(
- column: sql.ColumnElement[_T], *additional_columns, **kwargs
-) -> "ColumnProperty[_T]":
+ column: _ORMColumnExprArgument[_T],
+ *additional_columns: _ORMColumnExprArgument[Any],
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ descriptor: Optional[Any] = None,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
r"""Provide a column-level property for use with a mapping.
Column-based properties can normally be applied to the mapper's
@@ -452,13 +404,25 @@ def column_property(
expressions
"""
- return ColumnProperty(column, *additional_columns, **kwargs)
+ return ColumnProperty(
+ column,
+ *additional_columns,
+ group=group,
+ deferred=deferred,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ descriptor=descriptor,
+ active_history=active_history,
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
@overload
def composite(
class_: Type[_T],
- *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+ *attrs: _CompositeAttrType[Any],
**kwargs: Any,
) -> Composite[_T]:
...
@@ -466,7 +430,7 @@ def composite(
@overload
def composite(
- *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+ *attrs: _CompositeAttrType[Any],
**kwargs: Any,
) -> Composite[Any]:
...
@@ -474,7 +438,7 @@ def composite(
def composite(
class_: Any = None,
- *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]],
+ *attrs: _CompositeAttrType[Any],
**kwargs: Any,
) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
@@ -529,13 +493,13 @@ def composite(
def with_loader_criteria(
- entity_or_base,
- where_criteria,
- loader_only=False,
- include_aliases=False,
- propagate_to_loaders=True,
- track_closure_variables=True,
-) -> "LoaderCriteriaOption":
+ entity_or_base: _EntityType[Any],
+ where_criteria: _ColumnExpressionArgument[bool],
+ loader_only: bool = False,
+ include_aliases: bool = False,
+ propagate_to_loaders: bool = True,
+ track_closure_variables: bool = True,
+) -> LoaderCriteriaOption:
"""Add additional WHERE criteria to the load for all occurrences of
a particular entity.
@@ -711,180 +675,40 @@ def with_loader_criteria(
)
-@overload
-def relationship(
- argument: str,
- secondary=...,
- *,
- uselist: bool = ...,
- collection_class: Literal[None] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Any]:
- ...
-
-
-@overload
-def relationship(
- argument: str,
- secondary=...,
- *,
- uselist: bool = ...,
- collection_class: Type[Set] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Set[Any]]:
- ...
-
-
-@overload
-def relationship(
- argument: str,
- secondary=...,
- *,
- uselist: bool = ...,
- collection_class: Type[List] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[List[Any]]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]],
- secondary=...,
- *,
- uselist: Literal[False] = ...,
- collection_class: Literal[None] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[_T]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]],
- secondary=...,
- *,
- uselist: Literal[True] = ...,
- collection_class: Literal[None] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[List[_T]]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]],
- secondary=...,
- *,
- uselist: Union[Literal[None], Literal[True]] = ...,
- collection_class: Type[List] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[List[_T]]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]],
- secondary=...,
- *,
- uselist: Union[Literal[None], Literal[True]] = ...,
- collection_class: Type[Set] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Set[_T]]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]],
- secondary=...,
- *,
- uselist: Union[Literal[None], Literal[True]] = ...,
- collection_class: Type[Dict[Any, Any]] = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Dict[Any, _T]]:
- ...
-
-
-@overload
-def relationship(
- argument: _RelationshipArgumentType[_T],
- secondary=...,
- *,
- uselist: Literal[None] = ...,
- collection_class: Literal[None] = ...,
- primaryjoin=...,
- secondaryjoin=None,
- back_populates=None,
- **kw: Any,
-) -> Relationship[Any]:
- ...
-
-
-@overload
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]] = ...,
- secondary=...,
- *,
- uselist: Literal[True] = ...,
- collection_class: Any = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Any]:
- ...
-
-
-@overload
def relationship(
- argument: Literal[None] = ...,
- secondary=...,
- *,
- uselist: Optional[bool] = ...,
- collection_class: Any = ...,
- primaryjoin=...,
- secondaryjoin=...,
- back_populates=...,
- **kw: Any,
-) -> Relationship[Any]:
- ...
-
-
-def relationship(
- argument: Optional[_RelationshipArgumentType[_T]] = None,
- secondary=None,
+ argument: Optional[_RelationshipArgumentType[Any]] = None,
+ secondary: Optional[FromClause] = None,
*,
uselist: Optional[bool] = None,
- collection_class: Optional[Type[Collection]] = None,
- primaryjoin=None,
- secondaryjoin=None,
- back_populates=None,
+ collection_class: Optional[
+ Union[Type[Collection[Any]], Callable[[], Collection[Any]]]
+ ] = None,
+ primaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
+ back_populates: Optional[str] = None,
+ order_by: _ORMOrderByArgument = False,
+ backref: Optional[_ORMBackrefArgument] = None,
+ overlaps: Optional[str] = None,
+ post_update: bool = False,
+ cascade: str = "save-update, merge",
+ viewonly: bool = False,
+ lazy: _LazyLoadArgumentType = "select",
+ passive_deletes: bool = False,
+ passive_updates: bool = True,
+ active_history: bool = False,
+ enable_typechecks: bool = True,
+ foreign_keys: Optional[_ORMColCollectionArgument] = None,
+ remote_side: Optional[_ORMColCollectionArgument] = None,
+ join_depth: Optional[int] = None,
+ comparator_factory: Optional[Type[PropComparator[Any]]] = None,
+ single_parent: bool = False,
+ innerjoin: bool = False,
+ distinct_target_key: Optional[bool] = None,
+ load_on_pending: bool = False,
+ query_class: Optional[Type[Query[Any]]] = None,
+ info: Optional[_InfoType] = None,
+ omit_join: Literal[None, False] = None,
+ sync_backref: Optional[bool] = None,
**kw: Any,
) -> Relationship[Any]:
"""Provide a relationship between two mapped classes.
@@ -1098,13 +922,6 @@ def relationship(
:ref:`error_qzyx` - usage example
- :param bake_queries=True:
- Legacy parameter, not used.
-
- .. versionchanged:: 1.4.23 the "lambda caching" system is no longer
- used by loader strategies and the ``bake_queries`` parameter
- has no effect.
-
:param cascade:
A comma-separated list of cascade rules which determines how
Session operations should be "cascaded" from parent to child.
@@ -1701,18 +1518,42 @@ def relationship(
primaryjoin=primaryjoin,
secondaryjoin=secondaryjoin,
back_populates=back_populates,
+ order_by=order_by,
+ backref=backref,
+ overlaps=overlaps,
+ post_update=post_update,
+ cascade=cascade,
+ viewonly=viewonly,
+ lazy=lazy,
+ passive_deletes=passive_deletes,
+ passive_updates=passive_updates,
+ active_history=active_history,
+ enable_typechecks=enable_typechecks,
+ foreign_keys=foreign_keys,
+ remote_side=remote_side,
+ join_depth=join_depth,
+ comparator_factory=comparator_factory,
+ single_parent=single_parent,
+ innerjoin=innerjoin,
+ distinct_target_key=distinct_target_key,
+ load_on_pending=load_on_pending,
+ query_class=query_class,
+ info=info,
+ omit_join=omit_join,
+ sync_backref=sync_backref,
**kw,
)
def synonym(
- name,
- map_column=None,
- descriptor=None,
- comparator_factory=None,
- doc=None,
- info=None,
-) -> "Synonym[Any]":
+ name: str,
+ *,
+ map_column: Optional[bool] = None,
+ descriptor: Optional[Any] = None,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> Synonym[Any]:
"""Denote an attribute name as a synonym to a mapped property,
in that the attribute will mirror the value and expression behavior
of another attribute.
@@ -1951,8 +1792,8 @@ def deferred(*columns, **kw):
def query_expression(
- default_expr: sql.ColumnElement[_T] = sql.null(),
-) -> "Mapped[_T]":
+ default_expr: _ORMColumnExprArgument[_T] = sql.null(),
+) -> Mapped[_T]:
"""Indicate an attribute that populates from a query-time SQL expression.
:param default_expr: Optional SQL expression object that will be used in
@@ -2010,33 +1851,33 @@ def clear_mappers():
@overload
def aliased(
- element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"],
- alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
-) -> "AliasedClass[_T]":
+ element: _EntityType[_O],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+) -> AliasedClass[_O]:
...
@overload
def aliased(
- element: "FromClause",
- alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
-) -> "Alias":
+ element: FromClause,
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+) -> FromClause:
...
def aliased(
- element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"],
- alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
-) -> Union["AliasedClass[_T]", "Alias"]:
+ element: Union[_EntityType[_O], FromClause],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+) -> Union[AliasedClass[_O], FromClause]:
"""Produce an alias of the given element, usually an :class:`.AliasedClass`
instance.
@@ -2233,9 +2074,7 @@ def with_polymorphic(
)
-def join(
- left, right, onclause=None, isouter=False, full=False, join_to_left=None
-):
+def join(left, right, onclause=None, isouter=False, full=False):
r"""Produce an inner join between left and right clauses.
:func:`_orm.join` is an extension to the core join interface
@@ -2270,16 +2109,11 @@ def join(
See :ref:`orm_queryguide_joins` for information on modern usage
of ORM level joins.
- .. deprecated:: 0.8
-
- the ``join_to_left`` parameter is deprecated, and will be removed
- in a future release. The parameter has no effect.
-
"""
return _ORMJoin(left, right, onclause, isouter, full)
-def outerjoin(left, right, onclause=None, full=False, join_to_left=None):
+def outerjoin(left, right, onclause=None, full=False):
"""Produce a left outer join between left and right clauses.
This is the "outer join" version of the :func:`_orm.join` function,
diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py
index 4250cdbe1..339844f14 100644
--- a/lib/sqlalchemy/orm/_typing.py
+++ b/lib/sqlalchemy/orm/_typing.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import operator
from typing import Any
+from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -10,7 +11,9 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
-from sqlalchemy.orm.interfaces import UserDefinedOption
+from ..sql import roles
+from ..sql._typing import _HasClauseElement
+from ..sql.elements import ColumnElement
from ..util.typing import Protocol
from ..util.typing import TypeGuard
@@ -18,8 +21,12 @@ if TYPE_CHECKING:
from .attributes import AttributeImpl
from .attributes import CollectionAttributeImpl
from .base import PassiveFlag
+ from .decl_api import registry as _registry_type
from .descriptor_props import _CompositeClassProto
+ from .interfaces import MapperProperty
+ from .interfaces import UserDefinedOption
from .mapper import Mapper
+ from .relationships import Relationship
from .state import InstanceState
from .util import AliasedClass
from .util import AliasedInsp
@@ -27,21 +34,39 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
+
+# I would have preferred this were bound=object however it seems
+# to not travel in all situations when defined in that way.
_O = TypeVar("_O", bound=Any)
"""The 'ORM mapped object' type.
-I would have preferred this were bound=object however it seems
-to not travel in all situations when defined in that way.
+
"""
+if TYPE_CHECKING:
+ _RegistryType = _registry_type
+
_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"]
-_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"]
+_EntityType = Union[
+ Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"
+]
_InstanceDict = Dict[str, Any]
_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
+_ORMColumnExprArgument = Union[
+ ColumnElement[_T],
+ _HasClauseElement,
+ roles.ExpressionElementRole[_T],
+]
+
+# somehow Protocol didn't want to work for this one
+_ORMAdapterProto = Callable[
+ [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T]
+]
+
class _LoaderCallable(Protocol):
def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any:
@@ -60,10 +85,28 @@ def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
if TYPE_CHECKING:
+ def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]:
+ ...
+
+ def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]:
+ ...
+
+ def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
+ ...
+
+ def prop_is_relationship(
+ prop: MapperProperty[Any],
+ ) -> TypeGuard[Relationship[Any]]:
+ ...
+
def is_collection_impl(
impl: AttributeImpl,
) -> TypeGuard[CollectionAttributeImpl]:
...
else:
+ insp_is_mapper_property = operator.attrgetter("is_property")
+ insp_is_mapper = operator.attrgetter("is_mapper")
+ insp_is_aliased_class = operator.attrgetter("is_aliased_class")
is_collection_impl = operator.attrgetter("collection")
+ prop_is_relationship = operator.attrgetter("_is_relationship")
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 33ce96a19..41d944c57 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -44,7 +44,7 @@ from .base import instance_dict as instance_dict
from .base import instance_state as instance_state
from .base import instance_str
from .base import LOAD_AGAINST_COMMITTED
-from .base import manager_of_class
+from .base import manager_of_class as manager_of_class
from .base import Mapped as Mapped # noqa
from .base import NEVER_SET # noqa
from .base import NO_AUTOFLUSH
@@ -52,6 +52,7 @@ from .base import NO_CHANGE # noqa
from .base import NO_RAISE
from .base import NO_VALUE
from .base import NON_PERSISTENT_OK # noqa
+from .base import opt_manager_of_class as opt_manager_of_class
from .base import PASSIVE_CLASS_MISMATCH # noqa
from .base import PASSIVE_NO_FETCH
from .base import PASSIVE_NO_FETCH_RELATED # noqa
@@ -74,6 +75,7 @@ from ..sql import traversals
from ..sql import visitors
if TYPE_CHECKING:
+ from .interfaces import MapperProperty
from .state import InstanceState
from ..sql.dml import _DMLColumnElement
from ..sql.elements import ColumnElement
@@ -146,7 +148,7 @@ class QueryableAttribute(
self._of_type = of_type
self._extra_criteria = extra_criteria
- manager = manager_of_class(class_)
+ manager = opt_manager_of_class(class_)
# manager is None in the case of AliasedClass
if manager:
# propagate existing event listeners from
@@ -370,7 +372,7 @@ class QueryableAttribute(
return "%s.%s" % (self.class_.__name__, self.key)
@util.memoized_property
- def property(self):
+ def property(self) -> MapperProperty[_T]:
"""Return the :class:`.MapperProperty` associated with this
:class:`.QueryableAttribute`.
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 3fa855a4b..054d52d83 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -26,24 +26,25 @@ from typing import TypeVar
from typing import Union
from . import exc
+from ._typing import insp_is_mapper
from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
-from ..util.typing import Concatenate
from ..util.typing import Literal
-from ..util.typing import ParamSpec
from ..util.typing import Self
if typing.TYPE_CHECKING:
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
+ from .instrumentation import ClassManager
from .mapper import Mapper
from .state import InstanceState
from ..sql._typing import _InfoType
+
_T = TypeVar("_T", bound=Any)
_O = TypeVar("_O", bound=object)
@@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
-_Fn = TypeVar("_Fn", bound=Callable)
-_Args = ParamSpec("_Args")
+_F = TypeVar("_F", bound=Callable)
_Self = TypeVar("_Self")
def _assertions(
*assertions: Any,
-) -> Callable[
- [Callable[Concatenate[_Self, _Fn, _Args], _Self]],
- Callable[Concatenate[_Self, _Fn, _Args], _Self],
-]:
+) -> Callable[[_F], _F]:
@util.decorator
- def generate(
- fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
- ) -> _Self:
+ def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self:
for assertion in assertions:
assertion(self, fn.__name__)
fn(self, *args, **kw)
@@ -269,13 +264,13 @@ def _assertions(
return generate
-# these can be replaced by sqlalchemy.ext.instrumentation
-# if augmented class instrumentation is enabled.
-def manager_of_class(cls):
- return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+if TYPE_CHECKING:
+ def manager_of_class(cls: Type[Any]) -> ClassManager:
+ ...
-if TYPE_CHECKING:
+ def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]:
+ ...
def instance_state(instance: _O) -> InstanceState[_O]:
...
@@ -284,6 +279,20 @@ if TYPE_CHECKING:
...
else:
+ # these can be replaced by sqlalchemy.ext.instrumentation
+ # if augmented class instrumentation is enabled.
+
+ def manager_of_class(cls):
+ try:
+ return cls.__dict__[DEFAULT_MANAGER_ATTR]
+ except KeyError as ke:
+ raise exc.UnmappedClassError(
+ cls, f"Can't locate an instrumentation manager for class {cls}"
+ ) from ke
+
+ def opt_manager_of_class(cls):
+ return cls.__dict__.get(DEFAULT_MANAGER_ATTR)
+
instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
instance_dict = operator.attrgetter("__dict__")
@@ -458,11 +467,12 @@ else:
_state_mapper = util.dottedgetter("manager.mapper")
-@inspection._inspects(type)
-def _inspect_mapped_class(class_, configure=False):
+def _inspect_mapped_class(
+ class_: Type[_O], configure: bool = False
+) -> Optional[Mapper[_O]]:
try:
- class_manager = manager_of_class(class_)
- if not class_manager.is_mapped:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
return None
mapper = class_manager.mapper
except exc.NO_STATE:
@@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False):
return mapper
-def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
+@inspection._inspects(type)
+def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]:
+ try:
+ class_manager = opt_manager_of_class(class_)
+ if class_manager is None or not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+ return None
+ else:
+ return mapper
+
+
+def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]:
+ insp = inspection.inspect(arg, raiseerr=False)
+ if insp_is_mapper(insp):
+ return insp
+
+ raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}")
+
+
+def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]:
"""Given a class, return the primary :class:`_orm.Mapper` associated
with the key.
@@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]:
class InspectionAttr:
- """A base class applied to all ORM objects that can be returned
- by the :func:`_sa.inspect` function.
+ """A base class applied to all ORM objects and attributes that are
+ related to things that can be returned by the :func:`_sa.inspect` function.
The attributes defined here allow the usage of simple boolean
checks to test basic facts about the object returned.
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 419da65f7..4fee2d383 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -63,11 +63,14 @@ from ..sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ._typing import _InternalEntityType
+ from .mapper import Mapper
+ from .query import Query
from ..sql.compiler import _CompilerStackEntry
from ..sql.dml import _DMLTableElement
from ..sql.elements import ColumnElement
from ..sql.selectable import _LabelConventionCallable
from ..sql.selectable import SelectBase
+ from ..sql.type_api import TypeEngine
_path_registry = PathRegistry.root
@@ -211,6 +214,9 @@ class ORMCompileState(CompileState):
_for_refresh_state = False
_render_for_subquery = False
+ attributes: Dict[Any, Any]
+ global_attributes: Dict[Any, Any]
+
statement: Union[Select, FromStatement]
select_statement: Union[Select, FromStatement]
_entities: List[_QueryEntity]
@@ -1930,7 +1936,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
assert right_mapper
adapter = ORMAdapter(
- right, equivalents=right_mapper._equivalent_columns
+ inspect(right), equivalents=right_mapper._equivalent_columns
)
# if an alias() on the right side was generated,
@@ -2075,14 +2081,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _column_descriptions(
- query_or_select_stmt, compile_state=None, legacy=False
+ query_or_select_stmt: Union[Query, Select, FromStatement],
+ compile_state: Optional[ORMSelectCompileState] = None,
+ legacy: bool = False,
) -> List[ORMColumnDescription]:
if compile_state is None:
compile_state = ORMSelectCompileState._create_entities_collection(
query_or_select_stmt, legacy=legacy
)
ctx = compile_state
- return [
+ d = [
{
"name": ent._label_name,
"type": ent.type,
@@ -2093,17 +2101,10 @@ def _column_descriptions(
else None,
}
for ent, insp_ent in [
- (
- _ent,
- (
- inspect(_ent.entity_zero)
- if _ent.entity_zero is not None
- else None
- ),
- )
- for _ent in ctx._entities
+ (_ent, _ent.entity_zero) for _ent in ctx._entities
]
]
+ return d
def _legacy_filter_by_entity_zero(query_or_augmented_select):
@@ -2157,6 +2158,11 @@ class _QueryEntity:
_null_column_type = False
use_id_for_hash = False
+ _label_name: Optional[str]
+ type: Union[Type[Any], TypeEngine[Any]]
+ expr: Union[_InternalEntityType, ColumnElement[Any]]
+ entity_zero: Optional[_InternalEntityType]
+
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
@@ -2234,6 +2240,13 @@ class _MapperEntity(_QueryEntity):
"_polymorphic_discriminator",
)
+ expr: _InternalEntityType
+ mapper: Mapper[Any]
+ entity_zero: _InternalEntityType
+ is_aliased_class: bool
+ path: PathRegistry
+ _label_name: str
+
def __init__(
self, compile_state, entity, entities_collection, is_current_entities
):
@@ -2389,6 +2402,13 @@ class _BundleEntity(_QueryEntity):
"supports_single_entity",
)
+ _entities: List[_QueryEntity]
+ bundle: Bundle
+ type: Type[Any]
+ _label_name: str
+ supports_single_entity: bool
+ expr: Bundle
+
def __init__(
self,
compile_state,
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index 70507015b..0c990f809 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -50,7 +50,7 @@ from ..util import hybridproperty
from ..util import typing as compat_typing
if typing.TYPE_CHECKING:
- from .state import InstanceState # noqa
+ from .state import InstanceState
_T = TypeVar("_T", bound=Any)
@@ -280,7 +280,7 @@ class declared_attr(interfaces._MappedAttribute[_T]):
# for the span of the declarative scan_attributes() phase.
# to achieve this we look at the class manager that's configured.
cls = owner
- manager = attributes.manager_of_class(cls)
+ manager = attributes.opt_manager_of_class(cls)
if manager is None:
if not re.match(r"^__.+__$", self.fget.__name__):
# if there is no manager at all, then this class hasn't been
@@ -1294,8 +1294,8 @@ def as_declarative(**kw):
@inspection._inspects(
DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept
)
-def _inspect_decl_meta(cls):
- mp = _inspect_mapped_class(cls)
+def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]:
+ mp: Mapper[Any] = _inspect_mapped_class(cls)
if mp is None:
if _DeferredMapperConfig.has_cls(cls):
_DeferredMapperConfig.raise_unmapped_for_cls(cls)
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index 804d05ce1..9c79a4172 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -12,6 +12,8 @@ import collections
from typing import Any
from typing import Dict
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
import weakref
from . import attributes
@@ -42,6 +44,10 @@ from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+if TYPE_CHECKING:
+ from ._typing import _O
+ from ._typing import _RegistryType
+
def _declared_mapping_info(cls):
# deferred mapping
@@ -121,7 +127,7 @@ def _dive_for_cls_manager(cls):
return None
for base in cls.__mro__:
- manager = attributes.manager_of_class(base)
+ manager = attributes.opt_manager_of_class(base)
if manager:
return manager
return None
@@ -171,7 +177,7 @@ class _MapperConfig:
@classmethod
def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
- manager = attributes.manager_of_class(cls)
+ manager = attributes.opt_manager_of_class(cls)
if manager and manager.class_ is cls_:
raise exc.InvalidRequestError(
"Class %r already has been " "instrumented declaratively" % cls
@@ -191,7 +197,12 @@ class _MapperConfig:
return cfg_cls(registry, cls_, dict_, table, mapper_kw)
- def __init__(self, registry, cls_, mapper_kw):
+ def __init__(
+ self,
+ registry: _RegistryType,
+ cls_: Type[Any],
+ mapper_kw: Dict[str, Any],
+ ):
self.cls = util.assert_arg_type(cls_, type, "cls_")
self.classname = cls_.__name__
self.properties = util.OrderedDict()
@@ -206,7 +217,7 @@ class _MapperConfig:
init_method=registry.constructor,
)
else:
- manager = attributes.manager_of_class(self.cls)
+ manager = attributes.opt_manager_of_class(self.cls)
if not manager or not manager.is_mapped:
raise exc.InvalidRequestError(
"Class %s has no primary mapper configured. Configure "
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 8beac472e..4738d8c2c 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -122,7 +122,11 @@ class DescriptorProperty(MapperProperty[_T]):
_CompositeAttrType = Union[
- str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]"
+ str,
+ "Column[_T]",
+ "MappedColumn[_T]",
+ "InstrumentedAttribute[_T]",
+ "Mapped[_T]",
]
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index c531e7cf1..331c224ee 100644
--- a/lib/sqlalchemy/orm/events.py
+++ b/lib/sqlalchemy/orm/events.py
@@ -11,6 +11,9 @@
from __future__ import annotations
from typing import Any
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
import weakref
from . import instrumentation
@@ -27,6 +30,10 @@ from .. import exc
from .. import util
from ..util.compat import inspect_getfullargspec
+if TYPE_CHECKING:
+ from ._typing import _O
+ from .instrumentation import ClassManager
+
class InstrumentationEvents(event.Events):
"""Events related to class instrumentation events.
@@ -214,7 +221,7 @@ class InstanceEvents(event.Events):
if issubclass(target, mapperlib.Mapper):
return instrumentation.ClassManager
else:
- manager = instrumentation.manager_of_class(target)
+ manager = instrumentation.opt_manager_of_class(target)
if manager:
return manager
else:
@@ -613,8 +620,8 @@ class _EventsHold(event.RefCollection):
class _InstanceEventsHold(_EventsHold):
all_holds = weakref.WeakKeyDictionary()
- def resolve(self, class_):
- return instrumentation.manager_of_class(class_)
+ def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]:
+ return instrumentation.opt_manager_of_class(class_)
class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
pass
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
index 00829ecbb..529a7cd01 100644
--- a/lib/sqlalchemy/orm/exc.py
+++ b/lib/sqlalchemy/orm/exc.py
@@ -203,7 +203,10 @@ def _default_unmapped(cls) -> Optional[str]:
try:
mappers = base.manager_of_class(cls).mappers
- except (TypeError,) + NO_STATE:
+ except (
+ UnmappedClassError,
+ TypeError,
+ ) + NO_STATE:
mappers = {}
name = _safe_cls_name(cls)
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 0d4b630da..88ceacd07 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -33,10 +33,13 @@ alternate instrumentation forms.
from __future__ import annotations
from typing import Any
+from typing import Callable
from typing import Dict
from typing import Generic
from typing import Optional
from typing import Set
+from typing import Tuple
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref
@@ -53,7 +56,9 @@ from ..util import HasMemoized
from ..util.typing import Protocol
if TYPE_CHECKING:
+ from ._typing import _RegistryType
from .attributes import InstrumentedAttribute
+ from .decl_base import _MapperConfig
from .mapper import Mapper
from .state import InstanceState
from ..event import dispatcher
@@ -72,6 +77,11 @@ class _ExpiredAttributeLoaderProto(Protocol):
...
+class _ManagerFactory(Protocol):
+ def __call__(self, class_: Type[_O]) -> ClassManager[_O]:
+ ...
+
+
class ClassManager(
HasMemoized,
Dict[str, "InstrumentedAttribute[Any]"],
@@ -90,12 +100,12 @@ class ClassManager(
expired_attribute_loader: _ExpiredAttributeLoaderProto
"previously known as deferred_scalar_loader"
- init_method = None
+ init_method: Optional[Callable[..., None]]
- factory = None
+ factory: Optional[_ManagerFactory]
- declarative_scan = None
- registry = None
+ declarative_scan: Optional[weakref.ref[_MapperConfig]] = None
+ registry: Optional[_RegistryType] = None
@property
@util.deprecated(
@@ -122,11 +132,13 @@ class ClassManager(
self.local_attrs = {}
self.originals = {}
self._finalized = False
+ self.factory = None
+ self.init_method = None
self._bases = [
mgr
for mgr in [
- manager_of_class(base)
+ opt_manager_of_class(base)
for base in self.class_.__bases__
if isinstance(base, type)
]
@@ -139,7 +151,7 @@ class ClassManager(
self.dispatch._events._new_classmanager_instance(class_, self)
for basecls in class_.__mro__:
- mgr = manager_of_class(basecls)
+ mgr = opt_manager_of_class(basecls)
if mgr is not None:
self.dispatch._update(mgr.dispatch)
@@ -155,16 +167,18 @@ class ClassManager(
def _update_state(
self,
- finalize=False,
- mapper=None,
- registry=None,
- declarative_scan=None,
- expired_attribute_loader=None,
- init_method=None,
- ):
+ finalize: bool = False,
+ mapper: Optional[Mapper[_O]] = None,
+ registry: Optional[_RegistryType] = None,
+ declarative_scan: Optional[_MapperConfig] = None,
+ expired_attribute_loader: Optional[
+ _ExpiredAttributeLoaderProto
+ ] = None,
+ init_method: Optional[Callable[..., None]] = None,
+ ) -> None:
if mapper:
- self.mapper = mapper
+ self.mapper = mapper # type: ignore[assignment]
if registry:
registry._add_manager(self)
if declarative_scan:
@@ -350,7 +364,7 @@ class ClassManager(
def subclass_managers(self, recursive):
for cls in self.class_.__subclasses__():
- mgr = manager_of_class(cls)
+ mgr = opt_manager_of_class(cls)
if mgr is not None and mgr is not self:
yield mgr
if recursive:
@@ -374,7 +388,7 @@ class ClassManager(
self._reset_memoizations()
del self[key]
for cls in self.class_.__subclasses__():
- manager = manager_of_class(cls)
+ manager = opt_manager_of_class(cls)
if manager:
manager.uninstrument_attribute(key, True)
@@ -523,7 +537,7 @@ class _SerializeManager:
manager.dispatch.pickle(state, d)
def __call__(self, state, inst, state_dict):
- state.manager = manager = manager_of_class(self.class_)
+ state.manager = manager = opt_manager_of_class(self.class_)
if manager is None:
raise exc.UnmappedInstanceError(
inst,
@@ -546,9 +560,9 @@ class _SerializeManager:
class InstrumentationFactory:
"""Factory for new ClassManager instances."""
- def create_manager_for_cls(self, class_):
+ def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]:
assert class_ is not None
- assert manager_of_class(class_) is None
+ assert opt_manager_of_class(class_) is None
# give a more complicated subclass
# a chance to do what it wants here
@@ -557,6 +571,8 @@ class InstrumentationFactory:
if factory is None:
factory = ClassManager
manager = factory(class_)
+ else:
+ assert manager is not None
self._check_conflicts(class_, factory)
@@ -564,11 +580,15 @@ class InstrumentationFactory:
return manager
- def _locate_extended_factory(self, class_):
+ def _locate_extended_factory(
+ self, class_: Type[_O]
+ ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]:
"""Overridden by a subclass to do an extended lookup."""
return None, None
- def _check_conflicts(self, class_, factory):
+ def _check_conflicts(
+ self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]]
+ ):
"""Overridden by a subclass to test for conflicting factories."""
return
@@ -590,24 +610,25 @@ instance_state = _default_state_getter = base.instance_state
instance_dict = _default_dict_getter = base.instance_dict
manager_of_class = _default_manager_getter = base.manager_of_class
+opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class
def register_class(
- class_,
- finalize=True,
- mapper=None,
- registry=None,
- declarative_scan=None,
- expired_attribute_loader=None,
- init_method=None,
-):
+ class_: Type[_O],
+ finalize: bool = True,
+ mapper: Optional[Mapper[_O]] = None,
+ registry: Optional[_RegistryType] = None,
+ declarative_scan: Optional[_MapperConfig] = None,
+ expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None,
+ init_method: Optional[Callable[..., None]] = None,
+) -> ClassManager[_O]:
"""Register class instrumentation.
Returns the existing or newly created class manager.
"""
- manager = manager_of_class(class_)
+ manager = opt_manager_of_class(class_)
if manager is None:
manager = _instrumentation_factory.create_manager_for_cls(class_)
manager._update_state(
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index abc1300d8..0ca62b7e3 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -21,10 +21,15 @@ from __future__ import annotations
import collections
import typing
from typing import Any
+from typing import Callable
from typing import cast
+from typing import ClassVar
+from typing import Dict
+from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -45,7 +50,6 @@ from .base import NotExtension as NotExtension
from .base import ONETOMANY as ONETOMANY
from .base import SQLORMOperations
from .. import ColumnElement
-from .. import inspect
from .. import inspection
from .. import util
from ..sql import operators
@@ -53,19 +57,47 @@ from ..sql import roles
from ..sql import visitors
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
-from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
from ..util.typing import TypedDict
+
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
+ from ._typing import _IdentityKeyType
+ from ._typing import _InstanceDict
+ from ._typing import _InternalEntityType
+ from ._typing import _ORMAdapterProto
+ from ._typing import _ORMColumnExprArgument
+ from .attributes import InstrumentedAttribute
+ from .context import _MapperEntity
+ from .context import ORMCompileState
from .decl_api import RegistryType
+ from .loading import _PopulatorDict
+ from .mapper import Mapper
+ from .path_registry import AbstractEntityRegistry
+ from .path_registry import PathRegistry
+ from .query import Query
+ from .session import Session
+ from .state import InstanceState
+ from .strategy_options import _LoadElement
+ from .util import AliasedInsp
+ from .util import CascadeOptions
+ from .util import ORMAdapter
+ from ..engine.result import Result
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _DMLColumnArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _PropagateAttrsType
+ from ..sql.operators import OperatorType
+ from ..sql.util import ColumnAdapter
+ from ..sql.visitors import _TraverseInternalsType
_T = TypeVar("_T", bound=Any)
+_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]")
+
class ORMStatementRole(roles.StatementRole):
__slots__ = ()
@@ -91,7 +123,9 @@ class ORMFromClauseRole(roles.StrictFromClauseRole):
class ORMColumnDescription(TypedDict):
name: str
- type: Union[Type, TypeEngine]
+ # TODO: add python_type and sql_type here; combining them
+ # into "type" is a bad idea
+ type: Union[Type[Any], TypeEngine[Any]]
aliased: bool
expr: _ColumnsClauseArgument
entity: Optional[_ColumnsClauseArgument]
@@ -102,10 +136,10 @@ class _IntrospectsAnnotations:
def declarative_scan(
self,
- registry: "RegistryType",
- cls: type,
+ registry: RegistryType,
+ cls: Type[Any],
key: str,
- annotation: Optional[type],
+ annotation: Optional[Type[Any]],
is_dataclass_field: Optional[bool],
) -> None:
"""Perform class-specific initializaton at early declarative scanning
@@ -124,12 +158,12 @@ class _MapsColumns(_MappedAttribute[_T]):
__slots__ = ()
@property
- def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
+ def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
"""return a MapperProperty to be assigned to the declarative mapping"""
raise NotImplementedError()
@property
- def columns_to_assign(self) -> List[Column]:
+ def columns_to_assign(self) -> List[Column[_T]]:
"""A list of Column objects that should be declaratively added to the
new Table object.
@@ -139,7 +173,10 @@ class _MapsColumns(_MappedAttribute[_T]):
@inspection._self_inspects
class MapperProperty(
- HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots
+ HasCacheKey,
+ _MappedAttribute[_T],
+ InspectionAttrInfo,
+ util.MemoizedSlots,
):
"""Represent a particular class attribute mapped by :class:`_orm.Mapper`.
@@ -160,12 +197,12 @@ class MapperProperty(
"info",
)
- _cache_key_traversal = [
+ _cache_key_traversal: _TraverseInternalsType = [
("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("key", visitors.ExtendedInternalTraversal.dp_string),
]
- cascade = frozenset()
+ cascade: Optional[CascadeOptions] = None
"""The set of 'cascade' attribute names.
This collection is checked before the 'cascade_iterator' method is called.
@@ -184,14 +221,20 @@ class MapperProperty(
"""The :class:`_orm.PropComparator` instance that implements SQL
expression construction on behalf of this mapped attribute."""
- @property
- def _links_to_entity(self):
- """True if this MapperProperty refers to a mapped entity.
+ key: str
+ """name of class attribute"""
- Should only be True for Relationship, False for all others.
+ parent: Mapper[Any]
+ """the :class:`.Mapper` managing this property."""
- """
- raise NotImplementedError()
+ _is_relationship = False
+
+ _links_to_entity: bool
+ """True if this MapperProperty refers to a mapped entity.
+
+ Should only be True for Relationship, False for all others.
+
+ """
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
@@ -217,7 +260,14 @@ class MapperProperty(
"""
return {}
- def setup(self, context, query_entity, path, adapter, **kwargs):
+ def setup(
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: PathRegistry,
+ adapter: Optional[ColumnAdapter],
+ **kwargs: Any,
+ ) -> None:
"""Called by Query for the purposes of constructing a SQL statement.
Each MapperProperty associated with the target mapper processes the
@@ -227,16 +277,30 @@ class MapperProperty(
"""
def create_row_processor(
- self, context, query_entity, path, mapper, result, adapter, populators
- ):
+ self,
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: PathRegistry,
+ mapper: Mapper[Any],
+ result: Result,
+ adapter: Optional[ColumnAdapter],
+ populators: _PopulatorDict,
+ ) -> None:
"""Produce row processing functions and append to the given
set of populators lists.
"""
def cascade_iterator(
- self, type_, state, dict_, visited_states, halt_on=None
- ):
+ self,
+ type_: str,
+ state: InstanceState[Any],
+ dict_: _InstanceDict,
+ visited_states: Set[InstanceState[Any]],
+ halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None,
+ ) -> Iterator[
+ Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict]
+ ]:
"""Iterate through instances related to the given instance for
a particular 'cascade', starting with this MapperProperty.
@@ -251,7 +315,7 @@ class MapperProperty(
return iter(())
- def set_parent(self, parent, init):
+ def set_parent(self, parent: Mapper[Any], init: bool) -> None:
"""Set the parent mapper that references this MapperProperty.
This method is overridden by some subclasses to perform extra
@@ -260,7 +324,7 @@ class MapperProperty(
"""
self.parent = parent
- def instrument_class(self, mapper):
+ def instrument_class(self, mapper: Mapper[Any]) -> None:
"""Hook called by the Mapper to the property to initiate
instrumentation of the class attribute managed by this
MapperProperty.
@@ -280,11 +344,11 @@ class MapperProperty(
"""
- def __init__(self):
+ def __init__(self) -> None:
self._configure_started = False
self._configure_finished = False
- def init(self):
+ def init(self) -> None:
"""Called after all mappers are created to assemble
relationships between mappers and perform other post-mapper-creation
initialization steps.
@@ -296,7 +360,7 @@ class MapperProperty(
self._configure_finished = True
@property
- def class_attribute(self):
+ def class_attribute(self) -> InstrumentedAttribute[_T]:
"""Return the class-bound descriptor corresponding to this
:class:`.MapperProperty`.
@@ -319,9 +383,9 @@ class MapperProperty(
"""
- return getattr(self.parent.class_, self.key)
+ return getattr(self.parent.class_, self.key) # type: ignore
- def do_init(self):
+ def do_init(self) -> None:
"""Perform subclass-specific initialization post-mapper-creation
steps.
@@ -330,7 +394,7 @@ class MapperProperty(
"""
- def post_instrument_class(self, mapper):
+ def post_instrument_class(self, mapper: Mapper[Any]) -> None:
"""Perform instrumentation adjustments that need to occur
after init() has completed.
@@ -347,21 +411,21 @@ class MapperProperty(
def merge(
self,
- session,
- source_state,
- source_dict,
- dest_state,
- dest_dict,
- load,
- _recursive,
- _resolve_conflict_map,
- ):
+ session: Session,
+ source_state: InstanceState[Any],
+ source_dict: _InstanceDict,
+ dest_state: InstanceState[Any],
+ dest_dict: _InstanceDict,
+ load: bool,
+ _recursive: Set[InstanceState[Any]],
+ _resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
+ ) -> None:
"""Merge the attribute represented by this ``MapperProperty``
from source to destination object.
"""
- def __repr__(self):
+ def __repr__(self) -> str:
return "<%s at 0x%x; %s>" % (
self.__class__.__name__,
id(self),
@@ -452,21 +516,28 @@ class PropComparator(SQLORMOperations[_T]):
"""
- __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
+ __slots__ = "prop", "_parententity", "_adapt_to_entity"
__visit_name__ = "orm_prop_comparator"
+ _parententity: _InternalEntityType[Any]
+ _adapt_to_entity: Optional[AliasedInsp[Any]]
+
def __init__(
self,
- prop,
- parentmapper,
- adapt_to_entity=None,
+ prop: MapperProperty[_T],
+ parentmapper: _InternalEntityType[Any],
+ adapt_to_entity: Optional[AliasedInsp[Any]] = None,
):
- self.prop = self.property = prop
+ self.prop = prop
self._parententity = adapt_to_entity or parentmapper
self._adapt_to_entity = adapt_to_entity
- def __clause_element__(self):
+ @util.ro_non_memoized_property
+ def property(self) -> Optional[MapperProperty[_T]]:
+ return self.prop
+
+ def __clause_element__(self) -> _ORMColumnExprArgument[_T]:
raise NotImplementedError("%r" % self)
def _bulk_update_tuples(
@@ -480,22 +551,24 @@ class PropComparator(SQLORMOperations[_T]):
"""
- return [(self.__clause_element__(), value)]
+ return [(cast("_DMLColumnArgument", self.__clause_element__()), value)]
- def adapt_to_entity(self, adapt_to_entity):
+ def adapt_to_entity(
+ self, adapt_to_entity: AliasedInsp[Any]
+ ) -> PropComparator[_T]:
"""Return a copy of this PropComparator which will use the given
:class:`.AliasedInsp` to produce corresponding expressions.
"""
return self.__class__(self.prop, self._parententity, adapt_to_entity)
- @property
- def _parentmapper(self):
+ @util.ro_non_memoized_property
+ def _parentmapper(self) -> Mapper[Any]:
"""legacy; this is renamed to _parententity to be
compatible with QueryableAttribute."""
- return inspect(self._parententity).mapper
+ return self._parententity.mapper
- @property
- def _propagate_attrs(self):
+ @util.memoized_property
+ def _propagate_attrs(self) -> _PropagateAttrsType:
# this suits the case in coercions where we don't actually
# call ``__clause_element__()`` but still need to get
# resolved._propagate_attrs. See #6558.
@@ -507,12 +580,14 @@ class PropComparator(SQLORMOperations[_T]):
)
def _criterion_exists(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> ColumnElement[Any]:
return self.prop.comparator._criterion_exists(criterion, **kwargs)
- @property
- def adapter(self):
+ @util.ro_non_memoized_property
+ def adapter(self) -> Optional[_ORMAdapterProto[_T]]:
"""Produce a callable that adapts column expressions
to suit an aliased version of this comparator.
@@ -522,20 +597,20 @@ class PropComparator(SQLORMOperations[_T]):
else:
return self._adapt_to_entity._adapt_element
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def info(self) -> _InfoType:
- return self.property.info
+ return self.prop.info
@staticmethod
- def _any_op(a, b, **kwargs):
+ def _any_op(a: Any, b: Any, **kwargs: Any) -> Any:
return a.any(b, **kwargs)
@staticmethod
- def _has_op(left, other, **kwargs):
+ def _has_op(left: Any, other: Any, **kwargs: Any) -> Any:
return left.has(other, **kwargs)
@staticmethod
- def _of_type_op(a, class_):
+ def _of_type_op(a: Any, class_: Any) -> Any:
return a.of_type(class_)
any_op = cast(operators.OperatorType, _any_op)
@@ -545,16 +620,16 @@ class PropComparator(SQLORMOperations[_T]):
if typing.TYPE_CHECKING:
def operate(
- self, op: operators.OperatorType, *other: Any, **kwargs: Any
- ) -> "SQLCoreOperations[Any]":
+ self, op: OperatorType, *other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
...
def reverse_operate(
- self, op: operators.OperatorType, other: Any, **kwargs: Any
- ) -> "SQLCoreOperations[Any]":
+ self, op: OperatorType, other: Any, **kwargs: Any
+ ) -> ColumnElement[Any]:
...
- def of_type(self, class_) -> "SQLORMOperations[_T]":
+ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
r"""Redefine this object in terms of a polymorphic subclass,
:func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
construct.
@@ -578,9 +653,11 @@ class PropComparator(SQLORMOperations[_T]):
"""
- return self.operate(PropComparator.of_type_op, class_)
+ return self.operate(PropComparator.of_type_op, class_) # type: ignore
- def and_(self, *criteria) -> "SQLORMOperations[_T]":
+ def and_(
+ self, *criteria: _ColumnExpressionArgument[bool]
+ ) -> ColumnElement[bool]:
"""Add additional criteria to the ON clause that's represented by this
relationship attribute.
@@ -606,10 +683,12 @@ class PropComparator(SQLORMOperations[_T]):
:func:`.with_loader_criteria`
"""
- return self.operate(operators.and_, *criteria)
+ return self.operate(operators.and_, *criteria) # type: ignore
def any(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> ColumnElement[bool]:
r"""Return a SQL expression representing true if this element
references a member which meets the given criterion.
@@ -626,10 +705,14 @@ class PropComparator(SQLORMOperations[_T]):
"""
- return self.operate(PropComparator.any_op, criterion, **kwargs)
+ return self.operate( # type: ignore
+ PropComparator.any_op, criterion, **kwargs
+ )
def has(
- self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
) -> ColumnElement[bool]:
r"""Return a SQL expression representing true if this element
references a member which meets the given criterion.
@@ -646,7 +729,9 @@ class PropComparator(SQLORMOperations[_T]):
"""
- return self.operate(PropComparator.has_op, criterion, **kwargs)
+ return self.operate( # type: ignore
+ PropComparator.has_op, criterion, **kwargs
+ )
class StrategizedProperty(MapperProperty[_T]):
@@ -674,23 +759,30 @@ class StrategizedProperty(MapperProperty[_T]):
"strategy_key",
)
inherit_cache = True
- strategy_wildcard_key = None
+ strategy_wildcard_key: ClassVar[str]
strategy_key: Tuple[Any, ...]
- def _memoized_attr__wildcard_token(self):
+ _strategies: Dict[Tuple[Any, ...], LoaderStrategy]
+
+ def _memoized_attr__wildcard_token(self) -> Tuple[str]:
return (
f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}",
)
- def _memoized_attr__default_path_loader_key(self):
+ def _memoized_attr__default_path_loader_key(
+ self,
+ ) -> Tuple[str, Tuple[str]]:
return (
"loader",
(f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",),
)
- def _get_context_loader(self, context, path):
- load = None
+ def _get_context_loader(
+ self, context: ORMCompileState, path: AbstractEntityRegistry
+ ) -> Optional[_LoadElement]:
+
+ load: Optional[_LoadElement] = None
search_path = path[self]
@@ -714,7 +806,7 @@ class StrategizedProperty(MapperProperty[_T]):
return load
- def _get_strategy(self, key):
+ def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy:
try:
return self._strategies[key]
except KeyError:
@@ -768,11 +860,13 @@ class StrategizedProperty(MapperProperty[_T]):
):
self.strategy.init_class_attribute(mapper)
- _all_strategies = collections.defaultdict(dict)
+ _all_strategies: collections.defaultdict[
+ Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]]
+ ] = collections.defaultdict(dict)
@classmethod
- def strategy_for(cls, **kw):
- def decorate(dec_cls):
+ def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]:
+ def decorate(dec_cls: _TLS) -> _TLS:
# ensure each subclass of the strategy has its
# own _strategy_keys collection
if "_strategy_keys" not in dec_cls.__dict__:
@@ -785,7 +879,9 @@ class StrategizedProperty(MapperProperty[_T]):
return decorate
@classmethod
- def _strategy_lookup(cls, requesting_property, *key):
+ def _strategy_lookup(
+ cls, requesting_property: MapperProperty[Any], *key: Any
+ ) -> Type[LoaderStrategy]:
requesting_property.parent._with_polymorphic_mappers
for prop_cls in cls.__mro__:
@@ -984,10 +1080,10 @@ class MapperOption(ORMOption):
"""
- def process_query(self, query):
+ def process_query(self, query: Query[Any]) -> None:
"""Apply a modification to the given :class:`_query.Query`."""
- def process_query_conditionally(self, query):
+ def process_query_conditionally(self, query: Query[Any]) -> None:
"""same as process_query(), except that this option may not
apply to the given query.
@@ -1034,7 +1130,11 @@ class LoaderStrategy:
"strategy_opts",
)
- def __init__(self, parent, strategy_key):
+ _strategy_keys: ClassVar[List[Tuple[Any, ...]]]
+
+ def __init__(
+ self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...]
+ ):
self.parent_property = parent
self.is_class_level = False
self.parent = self.parent_property.parent
@@ -1042,12 +1142,18 @@ class LoaderStrategy:
self.strategy_key = strategy_key
self.strategy_opts = dict(strategy_key)
- def init_class_attribute(self, mapper):
+ def init_class_attribute(self, mapper: Mapper[Any]) -> None:
pass
def setup_query(
- self, compile_state, query_entity, path, loadopt, adapter, **kwargs
- ):
+ self,
+ compile_state: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ loadopt: Optional[_LoadElement],
+ adapter: Optional[ORMAdapter],
+ **kwargs: Any,
+ ) -> None:
"""Establish column and other state for a given QueryContext.
This method fulfills the contract specified by MapperProperty.setup().
@@ -1059,15 +1165,15 @@ class LoaderStrategy:
def create_row_processor(
self,
- context,
- query_entity,
- path,
- loadopt,
- mapper,
- result,
- adapter,
- populators,
- ):
+ context: ORMCompileState,
+ query_entity: _MapperEntity,
+ path: AbstractEntityRegistry,
+ loadopt: Optional[_LoadElement],
+ mapper: Mapper[Any],
+ result: Result,
+ adapter: Optional[ORMAdapter],
+ populators: _PopulatorDict,
+ ) -> None:
"""Establish row processing functions for a given QueryContext.
This method fulfills the contract specified by
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index ae083054c..d9949eb7a 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -16,7 +16,9 @@ as well as some of the attribute loading strategies.
from __future__ import annotations
from typing import Any
+from typing import Dict
from typing import Iterable
+from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
@@ -65,6 +67,9 @@ _O = TypeVar("_O", bound=object)
_new_runid = util.counter()
+_PopulatorDict = Dict[str, List[Tuple[str, Any]]]
+
+
def instances(cursor, context):
"""Return a :class:`.Result` given an ORM query context.
@@ -383,7 +388,7 @@ def get_from_identity(
mapper: Mapper[_O],
key: _IdentityKeyType[_O],
passive: PassiveFlag,
-) -> Union[Optional[_O], LoaderCallableStatus]:
+) -> Union[LoaderCallableStatus, Optional[_O]]:
"""Look up the given key in the given session's identity map,
check the object for expired state if found.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index abe11cc68..b37c080ea 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -23,12 +23,23 @@ import sys
import threading
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
from typing import Generic
+from typing import Iterable
from typing import Iterator
+from typing import List
+from typing import Mapping
from typing import Optional
+from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
import weakref
from . import attributes
@@ -39,8 +50,8 @@ from . import properties
from . import util as orm_util
from ._typing import _O
from .base import _class_to_mapper
+from .base import _parse_mapper_argument
from .base import _state_mapper
-from .base import class_mapper
from .base import PassiveFlag
from .base import state_str
from .interfaces import _MappedAttribute
@@ -58,6 +69,8 @@ from .. import log
from .. import schema
from .. import sql
from .. import util
+from ..event import dispatcher
+from ..event import EventTarget
from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
@@ -65,26 +78,68 @@ from ..sql import operators
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
+from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.schema import Table
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import HasMemoized
+from ..util import HasMemoized_ro_memoized_attribute
+from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
from ._typing import _InstanceDict
+ from ._typing import _ORMColumnExprArgument
+ from ._typing import _RegistryType
+ from .decl_api import registry
+ from .dependency import DependencyProcessor
+ from .descriptor_props import Composite
+ from .descriptor_props import Synonym
+ from .events import MapperEvents
from .instrumentation import ClassManager
+ from .path_registry import AbstractEntityRegistry
+ from .path_registry import CachingEntityRegistry
+ from .properties import ColumnProperty
+ from .relationships import Relationship
from .state import InstanceState
+ from ..engine import Row
+ from ..engine import RowMapping
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _EquivalentColumnMap
+ from ..sql.base import ReadOnlyColumnCollection
+ from ..sql.elements import ColumnClause
from ..sql.elements import ColumnElement
from ..sql.schema import Column
+ from ..sql.schema import Table
+ from ..sql.selectable import FromClause
+ from ..sql.selectable import TableClause
+ from ..sql.util import ColumnAdapter
+ from ..util import OrderedSet
-_mapper_registries = weakref.WeakKeyDictionary()
+_T = TypeVar("_T", bound=Any)
+_MP = TypeVar("_MP", bound="MapperProperty[Any]")
-def _all_registries():
+_WithPolymorphicArg = Union[
+ Literal["*"],
+ Tuple[
+ Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]],
+ Optional["FromClause"],
+ ],
+ Sequence[Union["Mapper[Any]", Type[Any]]],
+]
+
+
+_mapper_registries: weakref.WeakKeyDictionary[
+ _RegistryType, bool
+] = weakref.WeakKeyDictionary()
+
+
+def _all_registries() -> Set[registry]:
with _CONFIGURE_MUTEX:
return set(_mapper_registries)
-def _unconfigured_mappers():
+def _unconfigured_mappers() -> Iterator[Mapper[Any]]:
for reg in _all_registries():
for mapper in reg._mappers_to_configure():
yield mapper
@@ -107,9 +162,11 @@ _CONFIGURE_MUTEX = threading.RLock()
class Mapper(
ORMFromClauseRole,
ORMEntityColumnsClauseRole,
- sql_base.MemoizedHasCacheKey,
+ MemoizedHasCacheKey,
InspectionAttr,
log.Identified,
+ inspection.Inspectable["Mapper[_O]"],
+ EventTarget,
Generic[_O],
):
"""Defines an association between a Python class and a database table or
@@ -123,18 +180,11 @@ class Mapper(
"""
+ dispatch: dispatcher[Mapper[_O]]
+
_dispose_called = False
_ready_for_configure = False
- class_: Type[_O]
- """The class to which this :class:`_orm.Mapper` is mapped."""
-
- _identity_class: Type[_O]
-
- always_refresh: bool
- allow_partial_pks: bool
- version_id_col: Optional[ColumnElement[Any]]
-
@util.deprecated_params(
non_primary=(
"1.3",
@@ -148,33 +198,39 @@ class Mapper(
def __init__(
self,
class_: Type[_O],
- local_table=None,
- properties=None,
- primary_key=None,
- non_primary=False,
- inherits=None,
- inherit_condition=None,
- inherit_foreign_keys=None,
- always_refresh=False,
- version_id_col=None,
- version_id_generator=None,
- polymorphic_on=None,
- _polymorphic_map=None,
- polymorphic_identity=None,
- concrete=False,
- with_polymorphic=None,
- polymorphic_load=None,
- allow_partial_pks=True,
- batch=True,
- column_prefix=None,
- include_properties=None,
- exclude_properties=None,
- passive_updates=True,
- passive_deletes=False,
- confirm_deleted_rows=True,
- eager_defaults=False,
- legacy_is_orphan=False,
- _compiled_cache_size=100,
+ local_table: Optional[FromClause] = None,
+ properties: Optional[Mapping[str, MapperProperty[Any]]] = None,
+ primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None,
+ non_primary: bool = False,
+ inherits: Optional[Union[Mapper[Any], Type[Any]]] = None,
+ inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None,
+ inherit_foreign_keys: Optional[
+ Sequence[_ORMColumnExprArgument[Any]]
+ ] = None,
+ always_refresh: bool = False,
+ version_id_col: Optional[_ORMColumnExprArgument[Any]] = None,
+ version_id_generator: Optional[
+ Union[Literal[False], Callable[[Any], Any]]
+ ] = None,
+ polymorphic_on: Optional[
+ Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]]
+ ] = None,
+ _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None,
+ polymorphic_identity: Optional[Any] = None,
+ concrete: bool = False,
+ with_polymorphic: Optional[_WithPolymorphicArg] = None,
+ polymorphic_load: Optional[Literal["selectin", "inline"]] = None,
+ allow_partial_pks: bool = True,
+ batch: bool = True,
+ column_prefix: Optional[str] = None,
+ include_properties: Optional[Sequence[str]] = None,
+ exclude_properties: Optional[Sequence[str]] = None,
+ passive_updates: bool = True,
+ passive_deletes: bool = False,
+ confirm_deleted_rows: bool = True,
+ eager_defaults: bool = False,
+ legacy_is_orphan: bool = False,
+ _compiled_cache_size: int = 100,
):
r"""Direct constructor for a new :class:`_orm.Mapper` object.
@@ -593,8 +649,6 @@ class Mapper(
self.class_.__name__,
)
- self.class_manager = None
-
self._primary_key_argument = util.to_list(primary_key)
self.non_primary = non_primary
@@ -623,17 +677,36 @@ class Mapper(
self.concrete = concrete
self.single = False
- self.inherits = inherits
+
+ if inherits is not None:
+ self.inherits = _parse_mapper_argument(inherits)
+ else:
+ self.inherits = None
+
if local_table is not None:
self.local_table = coercions.expect(
roles.StrictFromClauseRole, local_table
)
+ elif self.inherits:
+ # note this is a new flow as of 2.0 so that
+ # .local_table need not be Optional
+ self.local_table = self.inherits.local_table
+ self.single = True
else:
- self.local_table = None
+ raise sa_exc.ArgumentError(
+ f"Mapper[{self.class_.__name__}(None)] has None for a "
+ "primary table argument and does not specify 'inherits'"
+ )
+
+ if inherit_condition is not None:
+ self.inherit_condition = coercions.expect(
+ roles.OnClauseRole, inherit_condition
+ )
+ else:
+ self.inherit_condition = None
- self.inherit_condition = inherit_condition
self.inherit_foreign_keys = inherit_foreign_keys
- self._init_properties = properties or {}
+ self._init_properties = dict(properties) if properties else {}
self._delete_orphans = []
self.batch = batch
self.eager_defaults = eager_defaults
@@ -694,7 +767,10 @@ class Mapper(
# while a configure_mappers() is occurring (and defer a
# configure_mappers() until construction succeeds)
with _CONFIGURE_MUTEX:
- self.dispatch._events._new_mapper_instance(class_, self)
+
+ cast("MapperEvents", self.dispatch._events)._new_mapper_instance(
+ class_, self
+ )
self._configure_inheritance()
self._configure_class_instrumentation()
self._configure_properties()
@@ -704,16 +780,21 @@ class Mapper(
self._log("constructed")
self._expire_memoizations()
- # major attributes initialized at the classlevel so that
- # they can be Sphinx-documented.
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ # ### BEGIN
+ # ATTRIBUTE DECLARATIONS START HERE
is_mapper = True
"""Part of the inspection API."""
represents_outer_join = False
+ registry: _RegistryType
+
@property
- def mapper(self):
+ def mapper(self) -> Mapper[_O]:
"""Part of the inspection API.
Returns self.
@@ -721,9 +802,6 @@ class Mapper(
"""
return self
- def _gen_cache_key(self, anon_map, bindparams):
- return (self,)
-
@property
def entity(self):
r"""Part of the inspection API.
@@ -733,49 +811,109 @@ class Mapper(
"""
return self.class_
- local_table = None
- """The :class:`_expression.Selectable` which this :class:`_orm.Mapper`
- manages.
+ class_: Type[_O]
+ """The class to which this :class:`_orm.Mapper` is mapped."""
+
+ _identity_class: Type[_O]
+
+ _delete_orphans: List[Tuple[str, Type[Any]]]
+ _dependency_processors: List[DependencyProcessor]
+ _memoized_values: Dict[Any, Callable[[], Any]]
+ _inheriting_mappers: util.WeakSequence[Mapper[Any]]
+ _all_tables: Set[Table]
+
+ _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
+ _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
+
+ _props: util.OrderedDict[str, MapperProperty[Any]]
+ _init_properties: Dict[str, MapperProperty[Any]]
+
+ _columntoproperty: _ColumnMapping
+
+ _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]]
+ _validate_polymorphic_identity: Optional[
+ Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None]
+ ]
+
+ tables: Sequence[Table]
+ """A sequence containing the collection of :class:`_schema.Table` objects
+ which this :class:`_orm.Mapper` is aware of.
+
+ If the mapper is mapped to a :class:`_expression.Join`, or an
+ :class:`_expression.Alias`
+ representing a :class:`_expression.Select`, the individual
+ :class:`_schema.Table`
+ objects that comprise the full construct will be represented here.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]]
+ """An immutable dictionary of attributes which have been decorated
+ using the :func:`_orm.validates` decorator.
+
+ The dictionary contains string attribute names as keys
+ mapped to the actual validation method.
+
+ """
+
+ always_refresh: bool
+ allow_partial_pks: bool
+ version_id_col: Optional[ColumnElement[Any]]
+
+ with_polymorphic: Optional[
+ Tuple[
+ Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]],
+ Optional["FromClause"],
+ ]
+ ]
+
+ version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]]
+
+ local_table: FromClause
+ """The immediate :class:`_expression.FromClause` which this
+ :class:`_orm.Mapper` refers towards.
- Typically is an instance of :class:`_schema.Table` or
- :class:`_expression.Alias`.
- May also be ``None``.
+ Typically is an instance of :class:`_schema.Table`, may be any
+ :class:`.FromClause`.
The "local" table is the
selectable that the :class:`_orm.Mapper` is directly responsible for
managing from an attribute access and flush perspective. For
- non-inheriting mappers, the local table is the same as the
- "mapped" table. For joined-table inheritance mappers, local_table
- will be the particular sub-table of the overall "join" which
- this :class:`_orm.Mapper` represents. If this mapper is a
- single-table inheriting mapper, local_table will be ``None``.
+ non-inheriting mappers, :attr:`.Mapper.local_table` will be the same
+ as :attr:`.Mapper.persist_selectable`. For inheriting mappers,
+ :attr:`.Mapper.local_table` refers to the specific portion of
+ :attr:`.Mapper.persist_selectable` that includes the columns to which
+ this :class:`.Mapper` is loading/persisting, such as a particular
+ :class:`.Table` within a join.
.. seealso::
:attr:`_orm.Mapper.persist_selectable`.
+ :attr:`_orm.Mapper.selectable`.
+
"""
- persist_selectable = None
- """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper`
+ persist_selectable: FromClause
+ """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper`
is mapped.
- Typically an instance of :class:`_schema.Table`,
- :class:`_expression.Join`, or :class:`_expression.Alias`.
-
- The :attr:`_orm.Mapper.persist_selectable` is separate from
- :attr:`_orm.Mapper.selectable` in that the former represents columns
- that are mapped on this class or its superclasses, whereas the
- latter may be a "polymorphic" selectable that contains additional columns
- which are in fact mapped on subclasses only.
+ Typically is an instance of :class:`_schema.Table`, may be any
+ :class:`.FromClause`.
- "persist selectable" is the "thing the mapper writes to" and
- "selectable" is the "thing the mapper selects from".
-
- :attr:`_orm.Mapper.persist_selectable` is also separate from
- :attr:`_orm.Mapper.local_table`, which represents the set of columns that
- are locally mapped on this class directly.
+ The :attr:`_orm.Mapper.persist_selectable` is similar to
+ :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that
+ represents the inheriting class hierarchy overall in an inheritance
+ scenario.
+ :attr.`.Mapper.persist_selectable` is also separate from the
+ :attr:`.Mapper.selectable` attribute, the latter of which may be an
+ alternate subquery used for selecting columns.
+ :attr.`.Mapper.persist_selectable` is oriented towards columns that
+ will be written on a persist operation.
.. seealso::
@@ -785,16 +923,15 @@ class Mapper(
"""
- inherits = None
+ inherits: Optional[Mapper[Any]]
"""References the :class:`_orm.Mapper` which this :class:`_orm.Mapper`
inherits from, if any.
- This is a *read only* attribute determined during mapper construction.
- Behavior is undefined if directly modified.
-
"""
- configured = False
+ inherit_condition: Optional[ColumnElement[bool]]
+
+ configured: bool = False
"""Represent ``True`` if this :class:`_orm.Mapper` has been configured.
This is a *read only* attribute determined during mapper construction.
@@ -806,7 +943,7 @@ class Mapper(
"""
- concrete = None
+ concrete: bool
"""Represent ``True`` if this :class:`_orm.Mapper` is a concrete
inheritance mapper.
@@ -815,21 +952,6 @@ class Mapper(
"""
- tables = None
- """An iterable containing the collection of :class:`_schema.Table` objects
- which this :class:`_orm.Mapper` is aware of.
-
- If the mapper is mapped to a :class:`_expression.Join`, or an
- :class:`_expression.Alias`
- representing a :class:`_expression.Select`, the individual
- :class:`_schema.Table`
- objects that comprise the full construct will be represented here.
-
- This is a *read only* attribute determined during mapper construction.
- Behavior is undefined if directly modified.
-
- """
-
primary_key: Tuple[Column[Any], ...]
"""An iterable containing the collection of :class:`_schema.Column`
objects
@@ -854,14 +976,6 @@ class Mapper(
"""
- class_: Type[_O]
- """The Python class which this :class:`_orm.Mapper` maps.
-
- This is a *read only* attribute determined during mapper construction.
- Behavior is undefined if directly modified.
-
- """
-
class_manager: ClassManager[_O]
"""The :class:`.ClassManager` which maintains event listeners
and class-bound descriptors for this :class:`_orm.Mapper`.
@@ -871,7 +985,7 @@ class Mapper(
"""
- single = None
+ single: bool
"""Represent ``True`` if this :class:`_orm.Mapper` is a single table
inheritance mapper.
@@ -882,7 +996,7 @@ class Mapper(
"""
- non_primary = None
+ non_primary: bool
"""Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary"
mapper, e.g. a mapper that is used only to select rows but not for
persistence management.
@@ -892,7 +1006,7 @@ class Mapper(
"""
- polymorphic_on = None
+ polymorphic_on: Optional[ColumnElement[Any]]
"""The :class:`_schema.Column` or SQL expression specified as the
``polymorphic_on`` argument
for this :class:`_orm.Mapper`, within an inheritance scenario.
@@ -906,7 +1020,7 @@ class Mapper(
"""
- polymorphic_map = None
+ polymorphic_map: Dict[Any, Mapper[Any]]
"""A mapping of "polymorphic identity" identifiers mapped to
:class:`_orm.Mapper` instances, within an inheritance scenario.
@@ -922,7 +1036,7 @@ class Mapper(
"""
- polymorphic_identity = None
+ polymorphic_identity: Optional[Any]
"""Represent an identifier which is matched against the
:attr:`_orm.Mapper.polymorphic_on` column during result row loading.
@@ -935,7 +1049,7 @@ class Mapper(
"""
- base_mapper = None
+ base_mapper: Mapper[Any]
"""The base-most :class:`_orm.Mapper` in an inheritance chain.
In a non-inheriting scenario, this attribute will always be this
@@ -948,7 +1062,7 @@ class Mapper(
"""
- columns = None
+ columns: ReadOnlyColumnCollection[str, Column[Any]]
"""A collection of :class:`_schema.Column` or other scalar expression
objects maintained by this :class:`_orm.Mapper`.
@@ -965,25 +1079,16 @@ class Mapper(
"""
- validators = None
- """An immutable dictionary of attributes which have been decorated
- using the :func:`_orm.validates` decorator.
-
- The dictionary contains string attribute names as keys
- mapped to the actual validation method.
-
- """
-
- c = None
+ c: ReadOnlyColumnCollection[str, Column[Any]]
"""A synonym for :attr:`_orm.Mapper.columns`."""
- @property
+ @util.non_memoized_property
@util.deprecated("1.3", "Use .persist_selectable")
def mapped_table(self):
return self.persist_selectable
@util.memoized_property
- def _path_registry(self) -> PathRegistry:
+ def _path_registry(self) -> CachingEntityRegistry:
return PathRegistry.per_mapper(self)
def _configure_inheritance(self):
@@ -994,8 +1099,6 @@ class Mapper(
self._inheriting_mappers = util.WeakSequence()
if self.inherits:
- if isinstance(self.inherits, type):
- self.inherits = class_mapper(self.inherits, configure=False)
if not issubclass(self.class_, self.inherits.class_):
raise sa_exc.ArgumentError(
"Class '%s' does not inherit from '%s'"
@@ -1011,11 +1114,9 @@ class Mapper(
"only allowed from a %s mapper"
% (np, self.class_.__name__, np)
)
- # inherit_condition is optional.
- if self.local_table is None:
- self.local_table = self.inherits.local_table
+
+ if self.single:
self.persist_selectable = self.inherits.persist_selectable
- self.single = True
elif self.local_table is not self.inherits.local_table:
if self.concrete:
self.persist_selectable = self.local_table
@@ -1068,6 +1169,7 @@ class Mapper(
self.local_table.description,
)
) from afe
+ assert self.inherits.persist_selectable is not None
self.persist_selectable = sql.join(
self.inherits.persist_selectable,
self.local_table,
@@ -1149,6 +1251,7 @@ class Mapper(
else:
self._all_tables = set()
self.base_mapper = self
+ assert self.local_table is not None
self.persist_selectable = self.local_table
if self.polymorphic_identity is not None:
self.polymorphic_map[self.polymorphic_identity] = self
@@ -1160,21 +1263,34 @@ class Mapper(
% self
)
- def _set_with_polymorphic(self, with_polymorphic):
+ def _set_with_polymorphic(
+ self, with_polymorphic: Optional[_WithPolymorphicArg]
+ ) -> None:
if with_polymorphic == "*":
self.with_polymorphic = ("*", None)
elif isinstance(with_polymorphic, (tuple, list)):
if isinstance(with_polymorphic[0], (str, tuple, list)):
- self.with_polymorphic = with_polymorphic
+ self.with_polymorphic = cast(
+ """Tuple[
+ Union[
+ Literal["*"],
+ Sequence[Union["Mapper[Any]", Type[Any]]],
+ ],
+ Optional["FromClause"],
+ ]""",
+ with_polymorphic,
+ )
else:
self.with_polymorphic = (with_polymorphic, None)
elif with_polymorphic is not None:
- raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
+ raise sa_exc.ArgumentError(
+ f"Invalid setting for with_polymorphic: {with_polymorphic!r}"
+ )
else:
self.with_polymorphic = None
if self.with_polymorphic and self.with_polymorphic[1] is not None:
- self.with_polymorphic = (
+ self.with_polymorphic = ( # type: ignore
self.with_polymorphic[0],
coercions.expect(
roles.StrictFromClauseRole,
@@ -1191,6 +1307,7 @@ class Mapper(
if self.with_polymorphic is None:
self._set_with_polymorphic((subcl,))
elif self.with_polymorphic[0] != "*":
+ assert isinstance(self.with_polymorphic[0], tuple)
self._set_with_polymorphic(
(self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1])
)
@@ -1241,7 +1358,7 @@ class Mapper(
# we expect that declarative has applied the class manager
# already and set up a registry. if this is None,
# this raises as of 2.0.
- manager = attributes.manager_of_class(self.class_)
+ manager = attributes.opt_manager_of_class(self.class_)
if self.non_primary:
if not manager or not manager.is_mapped:
@@ -1251,6 +1368,8 @@ class Mapper(
"Mapper." % self.class_
)
self.class_manager = manager
+
+ assert manager.registry is not None
self.registry = manager.registry
self._identity_class = manager.mapper._identity_class
manager.registry._add_non_primary_mapper(self)
@@ -1275,7 +1394,7 @@ class Mapper(
manager = instrumentation.register_class(
self.class_,
mapper=self,
- expired_attribute_loader=util.partial(
+ expired_attribute_loader=util.partial( # type: ignore
loading.load_scalar_attributes, self
),
# finalize flag means instrument the __init__ method
@@ -1284,6 +1403,8 @@ class Mapper(
)
self.class_manager = manager
+
+ assert manager.registry is not None
self.registry = manager.registry
# The remaining members can be added by any mapper,
@@ -1315,15 +1436,25 @@ class Mapper(
{name: (method, validation_opts)}
)
- def _set_dispose_flags(self):
+ def _set_dispose_flags(self) -> None:
self.configured = True
self._ready_for_configure = True
self._dispose_called = True
self.__dict__.pop("_configure_failed", None)
- def _configure_pks(self):
- self.tables = sql_util.find_tables(self.persist_selectable)
+ def _configure_pks(self) -> None:
+ self.tables = cast(
+ "List[Table]", sql_util.find_tables(self.persist_selectable)
+ )
+ for t in self.tables:
+ if not isinstance(t, Table):
+ raise sa_exc.ArgumentError(
+ f"ORM mappings can only be made against schema-level "
+ f"Table objects, not TableClause; got "
+ f"tableclause {t.name !r}"
+ )
+ self._all_tables.update(t for t in self.tables if isinstance(t, Table))
self._pks_by_table = {}
self._cols_by_table = {}
@@ -1335,16 +1466,16 @@ class Mapper(
pk_cols = util.column_set(c for c in all_cols if c.primary_key)
# identify primary key columns which are also mapped by this mapper.
- tables = set(self.tables + [self.persist_selectable])
- self._all_tables.update(tables)
- for t in tables:
- if t.primary_key and pk_cols.issuperset(t.primary_key):
+ for fc in set(self.tables).union([self.persist_selectable]):
+ if fc.primary_key and pk_cols.issuperset(fc.primary_key):
# ordering is important since it determines the ordering of
# mapper.primary_key (and therefore query.get())
- self._pks_by_table[t] = util.ordered_column_set(
- t.primary_key
- ).intersection(pk_cols)
- self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(
+ self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501
+ fc.primary_key
+ ).intersection(
+ pk_cols
+ )
+ self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501
all_cols
)
@@ -1386,10 +1517,15 @@ class Mapper(
self.primary_key = self.inherits.primary_key
else:
# determine primary key from argument or persist_selectable pks
+ primary_key: Collection[ColumnElement[Any]]
+
if self._primary_key_argument:
primary_key = [
- self.persist_selectable.corresponding_column(c)
- for c in self._primary_key_argument
+ cc if cc is not None else c
+ for cc, c in (
+ (self.persist_selectable.corresponding_column(c), c)
+ for c in self._primary_key_argument
+ )
]
else:
# if heuristically determined PKs, reduce to the minimal set
@@ -1413,7 +1549,7 @@ class Mapper(
# determine cols that aren't expressed within our tables; mark these
# as "read only" properties which are refreshed upon INSERT/UPDATE
- self._readonly_props = set(
+ self._readonly_props = {
self._columntoproperty[col]
for col in self._columntoproperty
if self._columntoproperty[col] not in self._identity_key_props
@@ -1421,12 +1557,12 @@ class Mapper(
not hasattr(col, "table")
or col.table not in self._cols_by_table
)
- )
+ }
- def _configure_properties(self):
+ def _configure_properties(self) -> None:
# TODO: consider using DedupeColumnCollection
- self.columns = self.c = sql_base.ColumnCollection()
+ self.columns = self.c = sql_base.ColumnCollection() # type: ignore
# object attribute names mapped to MapperProperty objects
self._props = util.OrderedDict()
@@ -1454,7 +1590,6 @@ class Mapper(
continue
column_key = (self.column_prefix or "") + column.key
-
if self._should_exclude(
column.key,
column_key,
@@ -1542,6 +1677,7 @@ class Mapper(
col = self.polymorphic_on
if isinstance(col, schema.Column) and (
self.with_polymorphic is None
+ or self.with_polymorphic[1] is None
or self.with_polymorphic[1].corresponding_column(col)
is None
):
@@ -1763,8 +1899,8 @@ class Mapper(
self.columns.add(col, key)
for col in prop.columns + prop._orig_columns:
- for col in col.proxy_set:
- self._columntoproperty[col] = prop
+ for proxy_col in col.proxy_set:
+ self._columntoproperty[proxy_col] = prop
prop.key = key
@@ -2033,7 +2169,9 @@ class Mapper(
self._check_configure()
return iter(self._props.values())
- def _mappers_from_spec(self, spec, selectable):
+ def _mappers_from_spec(
+ self, spec: Any, selectable: Optional[FromClause]
+ ) -> Sequence[Mapper[Any]]:
"""given a with_polymorphic() argument, return the set of mappers it
represents.
@@ -2044,7 +2182,7 @@ class Mapper(
if spec == "*":
mappers = list(self.self_and_descendants)
elif spec:
- mappers = set()
+ mapper_set = set()
for m in util.to_list(spec):
m = _class_to_mapper(m)
if not m.isa(self):
@@ -2053,10 +2191,10 @@ class Mapper(
)
if selectable is None:
- mappers.update(m.iterate_to_root())
+ mapper_set.update(m.iterate_to_root())
else:
- mappers.add(m)
- mappers = [m for m in self.self_and_descendants if m in mappers]
+ mapper_set.add(m)
+ mappers = [m for m in self.self_and_descendants if m in mapper_set]
else:
mappers = []
@@ -2067,7 +2205,9 @@ class Mapper(
mappers = [m for m in mappers if m.local_table in tables]
return mappers
- def _selectable_from_mappers(self, mappers, innerjoin):
+ def _selectable_from_mappers(
+ self, mappers: Iterable[Mapper[Any]], innerjoin: bool
+ ) -> FromClause:
"""given a list of mappers (assumed to be within this mapper's
inheritance hierarchy), construct an outerjoin amongst those mapper's
mapped tables.
@@ -2098,13 +2238,13 @@ class Mapper(
def _single_table_criterion(self):
if self.single and self.inherits and self.polymorphic_on is not None:
return self.polymorphic_on._annotate({"parentmapper": self}).in_(
- m.polymorphic_identity for m in self.self_and_descendants
+ [m.polymorphic_identity for m in self.self_and_descendants]
)
else:
return None
@HasMemoized.memoized_attribute
- def _with_polymorphic_mappers(self):
+ def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]:
self._check_configure()
if not self.with_polymorphic:
@@ -2124,8 +2264,8 @@ class Mapper(
"""
self._check_configure()
- @HasMemoized.memoized_attribute
- def _with_polymorphic_selectable(self):
+ @HasMemoized_ro_memoized_attribute
+ def _with_polymorphic_selectable(self) -> FromClause:
if not self.with_polymorphic:
return self.persist_selectable
@@ -2143,7 +2283,7 @@ class Mapper(
"""
- @HasMemoized.memoized_attribute
+ @HasMemoized_ro_memoized_attribute
def _insert_cols_evaluating_none(self):
return dict(
(
@@ -2250,7 +2390,7 @@ class Mapper(
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
- annotations = {
+ annotations: Dict[str, Any] = {
"entity_namespace": self,
"parententity": self,
"parentmapper": self,
@@ -2290,7 +2430,7 @@ class Mapper(
)
@property
- def selectable(self):
+ def selectable(self) -> FromClause:
"""The :class:`_schema.FromClause` construct this
:class:`_orm.Mapper` selects from by default.
@@ -2302,8 +2442,11 @@ class Mapper(
return self._with_polymorphic_selectable
def _with_polymorphic_args(
- self, spec=None, selectable=False, innerjoin=False
- ):
+ self,
+ spec: Any = None,
+ selectable: Union[Literal[False, None], FromClause] = False,
+ innerjoin: bool = False,
+ ) -> Tuple[Sequence[Mapper[Any]], FromClause]:
if selectable not in (None, False):
selectable = coercions.expect(
roles.StrictFromClauseRole, selectable, allow_select=True
@@ -2357,7 +2500,7 @@ class Mapper(
]
@HasMemoized.memoized_attribute
- def _polymorphic_adapter(self):
+ def _polymorphic_adapter(self) -> Optional[sql_util.ColumnAdapter]:
if self.with_polymorphic:
return sql_util.ColumnAdapter(
self.selectable, equivalents=self._equivalent_columns
@@ -2394,7 +2537,7 @@ class Mapper(
yield c
@HasMemoized.memoized_attribute
- def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]:
+ def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]:
"""A namespace of all :class:`.MapperProperty` objects
associated this mapper.
@@ -2432,7 +2575,7 @@ class Mapper(
return util.ReadOnlyProperties(self._props)
@HasMemoized.memoized_attribute
- def all_orm_descriptors(self):
+ def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]:
"""A namespace of all :class:`.InspectionAttr` attributes associated
with the mapped class.
@@ -2503,7 +2646,7 @@ class Mapper(
@HasMemoized.memoized_attribute
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def synonyms(self):
+ def synonyms(self) -> util.ReadOnlyProperties[Synonym[Any]]:
"""Return a namespace of all :class:`.Synonym`
properties maintained by this :class:`_orm.Mapper`.
@@ -2523,7 +2666,7 @@ class Mapper(
return self.class_
@HasMemoized.memoized_attribute
- def column_attrs(self):
+ def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]:
"""Return a namespace of all :class:`.ColumnProperty`
properties maintained by this :class:`_orm.Mapper`.
@@ -2536,9 +2679,9 @@ class Mapper(
"""
return self._filter_properties(properties.ColumnProperty)
- @util.preload_module("sqlalchemy.orm.relationships")
@HasMemoized.memoized_attribute
- def relationships(self):
+ @util.preload_module("sqlalchemy.orm.relationships")
+ def relationships(self) -> util.ReadOnlyProperties[Relationship[Any]]:
"""A namespace of all :class:`.Relationship` properties
maintained by this :class:`_orm.Mapper`.
@@ -2567,7 +2710,7 @@ class Mapper(
@HasMemoized.memoized_attribute
@util.preload_module("sqlalchemy.orm.descriptor_props")
- def composites(self):
+ def composites(self) -> util.ReadOnlyProperties[Composite[Any]]:
"""Return a namespace of all :class:`.Composite`
properties maintained by this :class:`_orm.Mapper`.
@@ -2582,7 +2725,9 @@ class Mapper(
util.preloaded.orm_descriptor_props.Composite
)
- def _filter_properties(self, type_):
+ def _filter_properties(
+ self, type_: Type[_MP]
+ ) -> util.ReadOnlyProperties[_MP]:
self._check_configure()
return util.ReadOnlyProperties(
util.OrderedDict(
@@ -2610,7 +2755,7 @@ class Mapper(
)
@HasMemoized.memoized_attribute
- def _equivalent_columns(self):
+ def _equivalent_columns(self) -> _EquivalentColumnMap:
"""Create a map of all equivalent columns, based on
the determination of column pairs that are equated to
one another based on inherit condition. This is designed
@@ -2630,18 +2775,18 @@ class Mapper(
}
"""
- result = util.column_dict()
+ result: _EquivalentColumnMap = {}
def visit_binary(binary):
if binary.operator == operators.eq:
if binary.left in result:
result[binary.left].add(binary.right)
else:
- result[binary.left] = util.column_set((binary.right,))
+ result[binary.left] = {binary.right}
if binary.right in result:
result[binary.right].add(binary.left)
else:
- result[binary.right] = util.column_set((binary.left,))
+ result[binary.right] = {binary.left}
for mapper in self.base_mapper.self_and_descendants:
if mapper.inherit_condition is not None:
@@ -2711,13 +2856,13 @@ class Mapper(
return False
- def common_parent(self, other):
+ def common_parent(self, other: Mapper[Any]) -> bool:
"""Return true if the given mapper shares a
common inherited parent as this mapper."""
return self.base_mapper is other.base_mapper
- def is_sibling(self, other):
+ def is_sibling(self, other: Mapper[Any]) -> bool:
"""return true if the other mapper is an inheriting sibling to this
one. common parent but different branch
@@ -2728,7 +2873,9 @@ class Mapper(
and not other.isa(self)
)
- def _canload(self, state, allow_subtypes):
+ def _canload(
+ self, state: InstanceState[Any], allow_subtypes: bool
+ ) -> bool:
s = self.primary_mapper()
if self.polymorphic_on is not None or allow_subtypes:
return _state_mapper(state).isa(s)
@@ -2738,19 +2885,19 @@ class Mapper(
def isa(self, other: Mapper[Any]) -> bool:
"""Return True if the this mapper inherits from the given mapper."""
- m = self
+ m: Optional[Mapper[Any]] = self
while m and m is not other:
m = m.inherits
return bool(m)
- def iterate_to_root(self):
- m = self
+ def iterate_to_root(self) -> Iterator[Mapper[Any]]:
+ m: Optional[Mapper[Any]] = self
while m:
yield m
m = m.inherits
@HasMemoized.memoized_attribute
- def self_and_descendants(self):
+ def self_and_descendants(self) -> Sequence[Mapper[Any]]:
"""The collection including this mapper and all descendant mappers.
This includes not just the immediately inheriting mappers but
@@ -2765,7 +2912,7 @@ class Mapper(
stack.extend(item._inheriting_mappers)
return util.WeakSequence(descendants)
- def polymorphic_iterator(self):
+ def polymorphic_iterator(self) -> Iterator[Mapper[Any]]:
"""Iterate through the collection including this mapper and
all descendant mappers.
@@ -2778,18 +2925,18 @@ class Mapper(
"""
return iter(self.self_and_descendants)
- def primary_mapper(self):
+ def primary_mapper(self) -> Mapper[Any]:
"""Return the primary mapper corresponding to this mapper's class key
(class)."""
return self.class_manager.mapper
@property
- def primary_base_mapper(self):
+ def primary_base_mapper(self) -> Mapper[Any]:
return self.class_manager.mapper.base_mapper
def _result_has_identity_key(self, result, adapter=None):
- pk_cols = self.primary_key
+ pk_cols: Sequence[ColumnClause[Any]] = self.primary_key
if adapter:
pk_cols = [adapter.columns[c] for c in pk_cols]
rk = result.keys()
@@ -2799,25 +2946,35 @@ class Mapper(
else:
return True
- def identity_key_from_row(self, row, identity_token=None, adapter=None):
+ def identity_key_from_row(
+ self,
+ row: Optional[Union[Row, RowMapping]],
+ identity_token: Optional[Any] = None,
+ adapter: Optional[ColumnAdapter] = None,
+ ) -> _IdentityKeyType[_O]:
"""Return an identity-map key for use in storing/retrieving an
item from the identity map.
- :param row: A :class:`.Row` instance. The columns which are
- mapped by this :class:`_orm.Mapper` should be locatable in the row,
- preferably via the :class:`_schema.Column`
- object directly (as is the case
- when a :func:`_expression.select` construct is executed), or
- via string names of the form ``<tablename>_<colname>``.
+ :param row: A :class:`.Row` or :class:`.RowMapping` produced from a
+ result set that selected from the ORM mapped primary key columns.
+
+ .. versionchanged:: 2.0
+ :class:`.Row` or :class:`.RowMapping` are accepted
+ for the "row" argument
"""
- pk_cols = self.primary_key
+ pk_cols: Sequence[ColumnClause[Any]] = self.primary_key
if adapter:
pk_cols = [adapter.columns[c] for c in pk_cols]
+ if hasattr(row, "_mapping"):
+ mapping = row._mapping # type: ignore
+ else:
+ mapping = cast("Mapping[Any, Any]", row)
+
return (
self._identity_class,
- tuple(row[column] for column in pk_cols),
+ tuple(mapping[column] for column in pk_cols), # type: ignore
identity_token,
)
@@ -2852,12 +3009,12 @@ class Mapper(
"""
state = attributes.instance_state(instance)
- return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
+ return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF)
def _identity_key_from_state(
self,
state: InstanceState[_O],
- passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE,
+ passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE,
) -> _IdentityKeyType[_O]:
dict_ = state.dict
manager = state.manager
@@ -2884,7 +3041,7 @@ class Mapper(
"""
state = attributes.instance_state(instance)
identity_key = self._identity_key_from_state(
- state, attributes.PASSIVE_OFF
+ state, PassiveFlag.PASSIVE_OFF
)
return identity_key[1]
@@ -2913,14 +3070,14 @@ class Mapper(
@HasMemoized.memoized_attribute
def _all_pk_cols(self):
- collection = set()
+ collection: Set[ColumnClause[Any]] = set()
for table in self.tables:
collection.update(self._pks_by_table[table])
return collection
@HasMemoized.memoized_attribute
def _should_undefer_in_wildcard(self):
- cols = set(self.primary_key)
+ cols: Set[ColumnElement[Any]] = set(self.primary_key)
if self.polymorphic_on is not None:
cols.add(self.polymorphic_on)
return cols
@@ -2951,11 +3108,11 @@ class Mapper(
state = attributes.instance_state(obj)
dict_ = attributes.instance_dict(obj)
return self._get_committed_state_attr_by_column(
- state, dict_, column, passive=attributes.PASSIVE_OFF
+ state, dict_, column, passive=PassiveFlag.PASSIVE_OFF
)
def _get_committed_state_attr_by_column(
- self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE
):
prop = self._columntoproperty[column]
@@ -2978,7 +3135,7 @@ class Mapper(
col_attribute_names = set(attribute_names).intersection(
state.mapper.column_attrs.keys()
)
- tables = set(
+ tables: Set[FromClause] = set(
chain(
*[
sql_util.find_tables(c, check_columns=True)
@@ -3002,7 +3159,7 @@ class Mapper(
state,
state.dict,
leftcol,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
if leftval in orm_util._none_set:
raise _OptGetColumnsNotAvailable()
@@ -3014,7 +3171,7 @@ class Mapper(
state,
state.dict,
rightcol,
- passive=attributes.PASSIVE_NO_INITIALIZE,
+ passive=PassiveFlag.PASSIVE_NO_INITIALIZE,
)
if rightval in orm_util._none_set:
raise _OptGetColumnsNotAvailable()
@@ -3022,7 +3179,7 @@ class Mapper(
None, rightval, type_=binary.right.type
)
- allconds = []
+ allconds: List[ColumnElement[bool]] = []
start = False
@@ -3035,6 +3192,9 @@ class Mapper(
elif not isinstance(mapper.local_table, expression.TableClause):
return None
if start and not mapper.single:
+ assert mapper.inherits
+ assert not mapper.concrete
+ assert mapper.inherit_condition is not None
allconds.append(mapper.inherit_condition)
tables.add(mapper.local_table)
@@ -3043,11 +3203,13 @@ class Mapper(
# descendant-most class should all be present and joined to each
# other.
try:
- allconds[0] = visitors.cloned_traverse(
+ _traversed = visitors.cloned_traverse(
allconds[0], {}, {"binary": visit_binary}
)
except _OptGetColumnsNotAvailable:
return None
+ else:
+ allconds[0] = _traversed
cond = sql.and_(*allconds)
@@ -3145,6 +3307,8 @@ class Mapper(
for pk in self.primary_key
]
+ in_expr: ColumnElement[Any]
+
if len(primary_key) > 1:
in_expr = sql.tuple_(*primary_key)
else:
@@ -3209,11 +3373,22 @@ class Mapper(
traverse all objects without relying on cascades.
"""
- visited_states = set()
+ visited_states: Set[InstanceState[Any]] = set()
prp, mpp = object(), object()
assert state.mapper.isa(self)
+ # this is actually a recursive structure, fully typing it seems
+ # a little too difficult for what it's worth here
+ visitables: Deque[
+ Tuple[
+ Deque[Any],
+ object,
+ Optional[InstanceState[Any]],
+ Optional[_InstanceDict],
+ ]
+ ]
+
visitables = deque(
[(deque(state.mapper._props.values()), prp, state, state.dict)]
)
@@ -3226,8 +3401,10 @@ class Mapper(
if item_type is prp:
prop = iterator.popleft()
- if type_ not in prop.cascade:
+ if not prop.cascade or type_ not in prop.cascade:
continue
+ assert parent_state is not None
+ assert parent_dict is not None
queue = deque(
prop.cascade_iterator(
type_,
@@ -3267,7 +3444,7 @@ class Mapper(
@HasMemoized.memoized_attribute
def _sorted_tables(self):
- table_to_mapper = {}
+ table_to_mapper: Dict[Table, Mapper[Any]] = {}
for mapper in self.base_mapper.self_and_descendants:
for t in mapper.tables:
@@ -3316,9 +3493,9 @@ class Mapper(
ret[t] = table_to_mapper[t]
return ret
- def _memo(self, key, callable_):
+ def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T:
if key in self._memoized_values:
- return self._memoized_values[key]
+ return cast(_T, self._memoized_values[key])
else:
self._memoized_values[key] = value = callable_()
return value
@@ -3328,14 +3505,22 @@ class Mapper(
"""memoized map of tables to collections of columns to be
synchronized upwards to the base mapper."""
- result = util.defaultdict(list)
+ result: util.defaultdict[
+ Table,
+ List[
+ Tuple[
+ Mapper[Any],
+ List[Tuple[ColumnElement[Any], ColumnElement[Any]]],
+ ]
+ ],
+ ] = util.defaultdict(list)
for table in self._sorted_tables:
cols = set(table.c)
for m in self.iterate_to_root():
if m._inherits_equated_pairs and cols.intersection(
reduce(
- set.union,
+ set.union, # type: ignore
[l.proxy_set for l, r in m._inherits_equated_pairs],
)
):
@@ -3440,7 +3625,7 @@ def _configure_registries(registries, cascade):
else:
return
- Mapper.dispatch._for_class(Mapper).before_configured()
+ Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501
# initialize properties on all mappers
# note that _mapper_registry is unordered, which
# may randomly conceal/reveal issues related to
@@ -3449,7 +3634,7 @@ def _configure_registries(registries, cascade):
_do_configure_registries(registries, cascade)
finally:
_already_compiling = False
- Mapper.dispatch._for_class(Mapper).after_configured()
+ Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore
@util.preload_module("sqlalchemy.orm.decl_api")
@@ -3480,7 +3665,7 @@ def _do_configure_registries(registries, cascade):
"Original exception was: %s"
% (mapper, mapper._configure_failed)
)
- e._configure_failed = mapper._configure_failed
+ e._configure_failed = mapper._configure_failed # type: ignore
raise e
if not mapper.configured:
@@ -3636,7 +3821,7 @@ def _event_on_init(state, args, kwargs):
instrumenting_mapper._set_polymorphic_identity(state)
-class _ColumnMapping(dict):
+class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]):
"""Error reporting helper for mapper._columntoproperty."""
__slots__ = ("mapper",)
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index e2cf1d5b0..361cea975 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -13,22 +13,70 @@ from __future__ import annotations
from functools import reduce
from itertools import chain
import logging
+import operator
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
+from typing import TYPE_CHECKING
from typing import Union
from . import base as orm_base
+from ._typing import insp_is_mapper_property
from .. import exc
-from .. import inspection
from .. import util
from ..sql import visitors
from ..sql.cache_key import HasCacheKey
+if TYPE_CHECKING:
+ from ._typing import _InternalEntityType
+ from .interfaces import MapperProperty
+ from .mapper import Mapper
+ from .relationships import Relationship
+ from .util import AliasedInsp
+ from ..sql.cache_key import _CacheKeyTraversalType
+ from ..sql.elements import BindParameter
+ from ..sql.visitors import anon_map
+ from ..util.typing import TypeGuard
+
+ def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
+ ...
+
+ def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]:
+ ...
+
+else:
+ is_root = operator.attrgetter("is_root")
+ is_entity = operator.attrgetter("is_entity")
+
+
+_SerializedPath = List[Any]
+
+_PathElementType = Union[
+ str, "_InternalEntityType[Any]", "MapperProperty[Any]"
+]
+
+# the representation is in fact
+# a tuple with alternating:
+# [_InternalEntityType[Any], Union[str, MapperProperty[Any]],
+# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...]
+# this might someday be a tuple of 2-tuples instead, but paths can be
+# chopped at odd intervals as well so this is less flexible
+_PathRepresentation = Tuple[_PathElementType, ...]
+
+_OddPathRepresentation = Sequence["_InternalEntityType[Any]"]
+_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]]
+
+
log = logging.getLogger(__name__)
-def _unreduce_path(path):
+def _unreduce_path(path: _SerializedPath) -> PathRegistry:
return PathRegistry.deserialize(path)
@@ -67,17 +115,18 @@ class PathRegistry(HasCacheKey):
is_token = False
is_root = False
has_entity = False
+ is_entity = False
- path: Tuple
- natural_path: Tuple
- parent: Union["PathRegistry", None]
+ path: _PathRepresentation
+ natural_path: _PathRepresentation
+ parent: Optional[PathRegistry]
+ root: RootRegistry
- root: "PathRegistry"
- _cache_key_traversal = [
+ _cache_key_traversal: _CacheKeyTraversalType = [
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
]
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
try:
return other is not None and self.path == other._path_for_compare
except AttributeError:
@@ -87,7 +136,7 @@ class PathRegistry(HasCacheKey):
)
return False
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> bool:
try:
return other is None or self.path != other._path_for_compare
except AttributeError:
@@ -98,74 +147,88 @@ class PathRegistry(HasCacheKey):
return True
@property
- def _path_for_compare(self):
+ def _path_for_compare(self) -> Optional[_PathRepresentation]:
return self.path
- def set(self, attributes, key, value):
+ def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None:
log.debug("set '%s' on path '%s' to '%s'", key, self, value)
attributes[(key, self.natural_path)] = value
- def setdefault(self, attributes, key, value):
+ def setdefault(
+ self, attributes: Dict[Any, Any], key: Any, value: Any
+ ) -> None:
log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
attributes.setdefault((key, self.natural_path), value)
- def get(self, attributes, key, value=None):
+ def get(
+ self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None
+ ) -> Any:
key = (key, self.natural_path)
if key in attributes:
return attributes[key]
else:
return value
- def __len__(self):
+ def __len__(self) -> int:
return len(self.path)
- def __hash__(self):
+ def __hash__(self) -> int:
return id(self)
- def __getitem__(self, key: Any) -> "PathRegistry":
+ def __getitem__(self, key: Any) -> PathRegistry:
raise NotImplementedError()
+ # TODO: what are we using this for?
@property
- def length(self):
+ def length(self) -> int:
return len(self.path)
- def pairs(self):
- path = self.path
- for i in range(0, len(path), 2):
- yield path[i], path[i + 1]
-
- def contains_mapper(self, mapper):
- for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]:
+ def pairs(
+ self,
+ ) -> Iterator[
+ Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]]
+ ]:
+ odd_path = cast(_OddPathRepresentation, self.path)
+ even_path = cast(_EvenPathRepresentation, odd_path)
+ for i in range(0, len(odd_path), 2):
+ yield odd_path[i], even_path[i + 1]
+
+ def contains_mapper(self, mapper: Mapper[Any]) -> bool:
+ _m_path = cast(_OddPathRepresentation, self.path)
+ for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]:
if path_mapper.is_mapper and path_mapper.isa(mapper):
return True
else:
return False
- def contains(self, attributes, key):
+ def contains(self, attributes: Dict[Any, Any], key: Any) -> bool:
return (key, self.path) in attributes
- def __reduce__(self):
+ def __reduce__(self) -> Any:
return _unreduce_path, (self.serialize(),)
@classmethod
- def _serialize_path(cls, path):
+ def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath:
+ _m_path = cast(_OddPathRepresentation, path)
+ _p_path = cast(_EvenPathRepresentation, path)
+
return list(
zip(
- [
+ tuple(
m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
- for m in [path[i] for i in range(0, len(path), 2)]
- ],
- [
- path[i].key if (path[i].is_property) else str(path[i])
- for i in range(1, len(path), 2)
- ]
- + [None],
+ for m in [_m_path[i] for i in range(0, len(_m_path), 2)]
+ ),
+ tuple(
+ p.key if insp_is_mapper_property(p) else str(p)
+ for p in [_p_path[i] for i in range(1, len(_p_path), 2)]
+ )
+ + (None,),
)
)
@classmethod
- def _deserialize_path(cls, path):
- def _deserialize_mapper_token(mcls):
+ def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation:
+ def _deserialize_mapper_token(mcls: Any) -> Any:
return (
# note: we likely dont want configure=True here however
# this is maintained at the moment for backwards compatibility
@@ -174,15 +237,15 @@ class PathRegistry(HasCacheKey):
else PathToken._intern[mcls]
)
- def _deserialize_key_token(mcls, key):
+ def _deserialize_key_token(mcls: Any, key: Any) -> Any:
if key is None:
return None
elif key in PathToken._intern:
return PathToken._intern[key]
else:
- return orm_base._inspect_mapped_class(
- mcls, configure=True
- ).attrs[key]
+ mp = orm_base._inspect_mapped_class(mcls, configure=True)
+ assert mp is not None
+ return mp.attrs[key]
p = tuple(
chain(
@@ -199,28 +262,63 @@ class PathRegistry(HasCacheKey):
p = p[0:-1]
return p
- def serialize(self) -> Sequence[Any]:
+ def serialize(self) -> _SerializedPath:
path = self.path
return self._serialize_path(path)
@classmethod
- def deserialize(cls, path: Sequence[Any]) -> PathRegistry:
+ def deserialize(cls, path: _SerializedPath) -> PathRegistry:
assert path is not None
p = cls._deserialize_path(path)
return cls.coerce(p)
+ @overload
@classmethod
- def per_mapper(cls, mapper):
+ def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry:
+ ...
+
+ @overload
+ @classmethod
+ def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry:
+ ...
+
+ @classmethod
+ def per_mapper(
+ cls, mapper: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
if mapper.is_mapper:
return CachingEntityRegistry(cls.root, mapper)
else:
return SlotsEntityRegistry(cls.root, mapper)
@classmethod
- def coerce(cls, raw):
- return reduce(lambda prev, next: prev[next], raw, cls.root)
+ def coerce(cls, raw: _PathRepresentation) -> PathRegistry:
+ def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry:
+ return prev[next_]
+
+ # can't quite get mypy to appreciate this one :)
+ return reduce(_red, raw, cls.root) # type: ignore
+
+ def __add__(self, other: PathRegistry) -> PathRegistry:
+ def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry:
+ return prev[next_]
- def token(self, token):
+ return reduce(_red, other.path, self)
+
+ def __str__(self) -> str:
+ return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]"
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self.path!r})"
+
+
+class CreatesToken(PathRegistry):
+ __slots__ = ()
+
+ is_aliased_class: bool
+ is_root: bool
+
+ def token(self, token: str) -> TokenRegistry:
if token.endswith(f":{_WILDCARD_TOKEN}"):
return TokenRegistry(self, token)
elif token.endswith(f":{_DEFAULT_TOKEN}"):
@@ -228,34 +326,47 @@ class PathRegistry(HasCacheKey):
else:
raise exc.ArgumentError(f"invalid token: {token}")
- def __add__(self, other):
- return reduce(lambda prev, next: prev[next], other.path, self)
-
- def __str__(self):
- return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]"
-
- def __repr__(self):
- return f"{self.__class__.__name__}({self.path!r})"
-
-class RootRegistry(PathRegistry):
+class RootRegistry(CreatesToken):
"""Root registry, defers to mappers so that
paths are maintained per-root-mapper.
"""
+ __slots__ = ()
+
inherit_cache = True
path = natural_path = ()
has_entity = False
is_aliased_class = False
is_root = True
+ is_unnatural = False
+
+ @overload
+ def __getitem__(self, entity: str) -> TokenRegistry:
+ ...
+
+ @overload
+ def __getitem__(
+ self, entity: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
+ ...
- def __getitem__(self, entity):
+ def __getitem__(
+ self, entity: Union[str, _InternalEntityType[Any]]
+ ) -> Union[TokenRegistry, AbstractEntityRegistry]:
if entity in PathToken._intern:
+ if TYPE_CHECKING:
+ assert isinstance(entity, str)
return TokenRegistry(self, PathToken._intern[entity])
else:
- return inspection.inspect(entity)._path_registry
+ try:
+ return entity._path_registry # type: ignore
+ except AttributeError:
+ raise IndexError(
+ f"invalid argument for RootRegistry.__getitem__: {entity}"
+ )
PathRegistry.root = RootRegistry()
@@ -264,17 +375,19 @@ PathRegistry.root = RootRegistry()
class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
"""cacheable string token"""
- _intern = {}
+ _intern: Dict[str, PathToken] = {}
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Tuple[Any, ...]:
return (str(self),)
@property
- def _path_for_compare(self):
+ def _path_for_compare(self) -> Optional[_PathRepresentation]:
return None
@classmethod
- def intern(cls, strvalue):
+ def intern(cls, strvalue: str) -> PathToken:
if strvalue in cls._intern:
return cls._intern[strvalue]
else:
@@ -287,7 +400,10 @@ class TokenRegistry(PathRegistry):
inherit_cache = True
- def __init__(self, parent, token):
+ token: str
+ parent: CreatesToken
+
+ def __init__(self, parent: CreatesToken, token: str):
token = PathToken.intern(token)
self.token = token
@@ -299,21 +415,33 @@ class TokenRegistry(PathRegistry):
is_token = True
- def generate_for_superclasses(self):
- if not self.parent.is_aliased_class and not self.parent.is_root:
- for ent in self.parent.mapper.iterate_to_root():
- yield TokenRegistry(self.parent.parent[ent], self.token)
+ def generate_for_superclasses(self) -> Iterator[PathRegistry]:
+ parent = self.parent
+ if is_root(parent):
+ yield self
+ return
+
+ if TYPE_CHECKING:
+ assert isinstance(parent, AbstractEntityRegistry)
+ if not parent.is_aliased_class:
+ for mp_ent in parent.mapper.iterate_to_root():
+ yield TokenRegistry(parent.parent[mp_ent], self.token)
elif (
- self.parent.is_aliased_class
- and self.parent.entity._is_with_polymorphic
+ parent.is_aliased_class
+ and cast(
+ "AliasedInsp[Any]",
+ parent.entity,
+ )._is_with_polymorphic
):
yield self
- for ent in self.parent.entity._with_polymorphic_entities:
- yield TokenRegistry(self.parent.parent[ent], self.token)
+ for ent in cast(
+ "AliasedInsp[Any]", parent.entity
+ )._with_polymorphic_entities:
+ yield TokenRegistry(parent.parent[ent], self.token)
else:
yield self
- def __getitem__(self, entity):
+ def __getitem__(self, entity: Any) -> Any:
try:
return self.path[entity]
except TypeError as err:
@@ -321,23 +449,42 @@ class TokenRegistry(PathRegistry):
class PropRegistry(PathRegistry):
- is_unnatural = False
+ __slots__ = (
+ "prop",
+ "parent",
+ "path",
+ "natural_path",
+ "has_entity",
+ "entity",
+ "mapper",
+ "_wildcard_path_loader_key",
+ "_default_path_loader_key",
+ "_loader_key",
+ "is_unnatural",
+ )
inherit_cache = True
- def __init__(self, parent, prop):
+ prop: MapperProperty[Any]
+ mapper: Optional[Mapper[Any]]
+ entity: Optional[_InternalEntityType[Any]]
+
+ def __init__(
+ self, parent: AbstractEntityRegistry, prop: MapperProperty[Any]
+ ):
# restate this path in terms of the
# given MapperProperty's parent.
- insp = inspection.inspect(parent[-1])
- natural_parent = parent
+ insp = cast("_InternalEntityType[Any]", parent[-1])
+ natural_parent: AbstractEntityRegistry = parent
+ self.is_unnatural = False
- if not insp.is_aliased_class or insp._use_mapper_path:
+ if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore
parent = natural_parent = parent.parent[prop.parent]
elif (
insp.is_aliased_class
and insp.with_polymorphic_mappers
and prop.parent in insp.with_polymorphic_mappers
):
- subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
+ subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501
parent = parent.parent[subclass_entity]
# when building a path where with_polymorphic() is in use,
@@ -388,43 +535,74 @@ class PropRegistry(PathRegistry):
self.parent = parent
self.path = parent.path + (prop,)
self.natural_path = natural_parent.natural_path + (prop,)
+ self.has_entity = prop._links_to_entity
+ if prop._is_relationship:
+ if TYPE_CHECKING:
+ assert isinstance(prop, Relationship)
+ self.entity = prop.entity
+ self.mapper = prop.mapper
+ else:
+ self.entity = None
+ self.mapper = None
self._wildcard_path_loader_key = (
"loader",
- parent.path + self.prop._wildcard_token,
+ parent.path + self.prop._wildcard_token, # type: ignore
)
self._default_path_loader_key = self.prop._default_path_loader_key
self._loader_key = ("loader", self.natural_path)
- @util.memoized_property
- def has_entity(self):
- return self.prop._links_to_entity
+ @property
+ def entity_path(self) -> AbstractEntityRegistry:
+ assert self.entity is not None
+ return self[self.entity]
- @util.memoized_property
- def entity(self):
- return self.prop.entity
+ @overload
+ def __getitem__(self, entity: slice) -> _PathRepresentation:
+ ...
- @property
- def mapper(self):
- return self.prop.mapper
+ @overload
+ def __getitem__(self, entity: int) -> _PathElementType:
+ ...
- @property
- def entity_path(self):
- return self[self.entity]
+ @overload
+ def __getitem__(
+ self, entity: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
+ ...
- def __getitem__(self, entity):
+ def __getitem__(
+ self, entity: Union[int, slice, _InternalEntityType[Any]]
+ ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
if isinstance(entity, (int, slice)):
return self.path[entity]
else:
return SlotsEntityRegistry(self, entity)
-class AbstractEntityRegistry(PathRegistry):
- __slots__ = ()
+class AbstractEntityRegistry(CreatesToken):
+ __slots__ = (
+ "key",
+ "parent",
+ "is_aliased_class",
+ "path",
+ "entity",
+ "natural_path",
+ )
has_entity = True
-
- def __init__(self, parent, entity):
+ is_entity = True
+
+ parent: Union[RootRegistry, PropRegistry]
+ key: _InternalEntityType[Any]
+ entity: _InternalEntityType[Any]
+ is_aliased_class: bool
+
+ def __init__(
+ self,
+ parent: Union[RootRegistry, PropRegistry],
+ entity: _InternalEntityType[Any],
+ ):
self.key = entity
self.parent = parent
self.is_aliased_class = entity.is_aliased_class
@@ -447,11 +625,11 @@ class AbstractEntityRegistry(PathRegistry):
if parent.path and (self.is_aliased_class or parent.is_unnatural):
# this is an infrequent code path used only for loader strategies
# that also make use of of_type().
- if entity.mapper.isa(parent.natural_path[-1].entity):
+ if entity.mapper.isa(parent.natural_path[-1].entity): # type: ignore # noqa: E501
self.natural_path = parent.natural_path + (entity.mapper,)
else:
self.natural_path = parent.natural_path + (
- parent.natural_path[-1].entity,
+ parent.natural_path[-1].entity, # type: ignore
)
# it seems to make sense that since these paths get mixed up
# with statements that are cached or not, we should make
@@ -465,19 +643,35 @@ class AbstractEntityRegistry(PathRegistry):
self.natural_path = self.path
@property
- def entity_path(self):
+ def entity_path(self) -> PathRegistry:
return self
@property
- def mapper(self):
- return inspection.inspect(self.entity).mapper
+ def mapper(self) -> Mapper[Any]:
+ return self.entity.mapper
- def __bool__(self):
+ def __bool__(self) -> bool:
return True
- __nonzero__ = __bool__
+ @overload
+ def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: str) -> TokenRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: int) -> _PathElementType:
+ ...
- def __getitem__(self, entity):
+ @overload
+ def __getitem__(self, entity: slice) -> _PathRepresentation:
+ ...
+
+ def __getitem__(
+ self, entity: Any
+ ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
if isinstance(entity, (int, slice)):
return self.path[entity]
elif entity in PathToken._intern:
@@ -491,31 +685,40 @@ class SlotsEntityRegistry(AbstractEntityRegistry):
# version
inherit_cache = True
- __slots__ = (
- "key",
- "parent",
- "is_aliased_class",
- "entity",
- "path",
- "natural_path",
- )
+
+class _ERDict(Dict[Any, Any]):
+ def __init__(self, registry: CachingEntityRegistry):
+ self.registry = registry
+
+ def __missing__(self, key: Any) -> PropRegistry:
+ self[key] = item = PropRegistry(self.registry, key)
+
+ return item
-class CachingEntityRegistry(AbstractEntityRegistry, dict):
+class CachingEntityRegistry(AbstractEntityRegistry):
# for long lived mapper, return dict based caching
# version that creates reference cycles
+ __slots__ = ("_cache",)
+
inherit_cache = True
- def __getitem__(self, entity):
+ def __init__(
+ self,
+ parent: Union[RootRegistry, PropRegistry],
+ entity: _InternalEntityType[Any],
+ ):
+ super().__init__(parent, entity)
+ self._cache = _ERDict(self)
+
+ def pop(self, key: Any, default: Any) -> Any:
+ return self._cache.pop(key, default)
+
+ def __getitem__(self, entity: Any) -> Any:
if isinstance(entity, (int, slice)):
return self.path[entity]
elif isinstance(entity, PathToken):
return TokenRegistry(self, entity)
else:
- return dict.__getitem__(self, entity)
-
- def __missing__(self, key):
- self[key] = item = PropRegistry(self, key)
-
- return item
+ return self._cache[entity]
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index c01825b6d..9f37e8457 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -19,6 +19,8 @@ from typing import cast
from typing import List
from typing import Optional
from typing import Set
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from . import attributes
@@ -38,17 +40,22 @@ from .util import _orm_full_deannotate
from .. import exc as sa_exc
from .. import ForeignKey
from .. import log
-from .. import sql
from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
from ..sql.schema import Column
+from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import NoneType
+if TYPE_CHECKING:
+ from ._typing import _ORMColumnExprArgument
+ from ..sql._typing import _InfoType
+ from ..sql.elements import ColumnElement
+
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
@@ -78,6 +85,10 @@ class ColumnProperty(
inherit_cache = True
_links_to_entity = False
+ columns: List[ColumnElement[Any]]
+
+ _is_polymorphic_discriminator: bool
+
__slots__ = (
"_orig_columns",
"columns",
@@ -99,7 +110,19 @@ class ColumnProperty(
)
def __init__(
- self, column: sql.ColumnElement[_T], *additional_columns, **kwargs
+ self,
+ column: _ORMColumnExprArgument[_T],
+ *additional_columns: _ORMColumnExprArgument[Any],
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator]] = None,
+ descriptor: Optional[Any] = None,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ _instrument: bool = True,
):
super(ColumnProperty, self).__init__()
columns = (column,) + additional_columns
@@ -112,23 +135,24 @@ class ColumnProperty(
)
for c in columns
]
- self.parent = self.key = None
- self.group = kwargs.pop("group", None)
- self.deferred = kwargs.pop("deferred", False)
- self.raiseload = kwargs.pop("raiseload", False)
- self.instrument = kwargs.pop("_instrument", True)
- self.comparator_factory = kwargs.pop(
- "comparator_factory", self.__class__.Comparator
+ self.group = group
+ self.deferred = deferred
+ self.raiseload = raiseload
+ self.instrument = _instrument
+ self.comparator_factory = (
+ comparator_factory
+ if comparator_factory is not None
+ else self.__class__.Comparator
)
- self.descriptor = kwargs.pop("descriptor", None)
- self.active_history = kwargs.pop("active_history", False)
- self.expire_on_flush = kwargs.pop("expire_on_flush", True)
+ self.descriptor = descriptor
+ self.active_history = active_history
+ self.expire_on_flush = expire_on_flush
- if "info" in kwargs:
- self.info = kwargs.pop("info")
+ if info is not None:
+ self.info = info
- if "doc" in kwargs:
- self.doc = kwargs.pop("doc")
+ if doc is not None:
+ self.doc = doc
else:
for col in reversed(self.columns):
doc = getattr(col, "doc", None)
@@ -138,12 +162,6 @@ class ColumnProperty(
else:
self.doc = None
- if kwargs:
- raise TypeError(
- "%s received unexpected keyword argument(s): %s"
- % (self.__class__.__name__, ", ".join(sorted(kwargs.keys())))
- )
-
util.set_creation_order(self)
self.strategy_key = (
@@ -445,7 +463,10 @@ class MappedColumn(
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
- self._has_nullable = "nullable" in kw
+ self._has_nullable = "nullable" in kw and kw.get("nullable") not in (
+ None,
+ SchemaConst.NULL_UNSPECIFIED,
+ )
util.set_creation_order(self)
def _copy(self, **kw):
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index a754bd4f2..395d01a1e 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -30,6 +30,7 @@ from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import exc as orm_exc
from . import interfaces
@@ -77,6 +78,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
if TYPE_CHECKING:
from ..sql.selectable import _SetupJoinsElement
+ from ..sql.selectable import Alias
+ from ..sql.selectable import Subquery
__all__ = ["Query", "QueryContext"]
@@ -2769,14 +2772,14 @@ class AliasOption(interfaces.LoaderOption):
"for entities to be matched up to a query that is established "
"via :meth:`.Query.from_statement` and now does nothing.",
)
- def __init__(self, alias):
+ def __init__(self, alias: Union[Alias, Subquery]):
r"""Return a :class:`.MapperOption` that will indicate to the
:class:`_query.Query`
that the main table has been aliased.
"""
- def process_compile_state(self, compile_state):
+ def process_compile_state(self, compile_state: ORMCompileState):
pass
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 58c7c4efd..66021c9c2 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -21,7 +21,10 @@ import re
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import Optional
+from typing import Sequence
+from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
@@ -30,6 +33,7 @@ import weakref
from . import attributes
from . import strategy_options
from .base import _is_mapped_class
+from .base import class_mapper
from .base import state_str
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
@@ -53,7 +57,9 @@ from ..sql import expression
from ..sql import operators
from ..sql import roles
from ..sql import visitors
-from ..sql.elements import SQLCoreOperations
+from ..sql._typing import _ColumnExpressionArgument
+from ..sql._typing import _HasClauseElement
+from ..sql.elements import ColumnClause
from ..sql.util import _deep_deannotate
from ..sql.util import _shallow_annotate
from ..sql.util import adapt_criterion_to_null
@@ -61,11 +67,14 @@ from ..sql.util import ClauseAdapter
from ..sql.util import join_condition
from ..sql.util import selectables_overlap
from ..sql.util import visit_binary_product
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from .mapper import Mapper
from .util import AliasedClass
from .util import AliasedInsp
+ from ..sql.elements import ColumnElement
_T = TypeVar("_T", bound=Any)
_PT = TypeVar("_PT", bound=Any)
@@ -81,6 +90,34 @@ _RelationshipArgumentType = Union[
Callable[[], "AliasedClass[_T]"],
]
+_LazyLoadArgumentType = Literal[
+ "select",
+ "joined",
+ "selectin",
+ "subquery",
+ "raise",
+ "raise_on_sql",
+ "noload",
+ "immediate",
+ "dynamic",
+ True,
+ False,
+ None,
+]
+
+
+_RelationshipJoinConditionArgument = Union[
+ str, _ColumnExpressionArgument[bool]
+]
+_ORMOrderByArgument = Union[
+ Literal[False], str, _ColumnExpressionArgument[Any]
+]
+_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
+_ORMColCollectionArgument = Union[
+ str,
+ Sequence[Union[ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole]],
+]
+
def remote(expr):
"""Annotate a portion of a primaryjoin expression
@@ -144,6 +181,7 @@ class Relationship(
inherit_cache = True
_links_to_entity = True
+ _is_relationship = True
_persistence_only = dict(
passive_deletes=False,
@@ -159,38 +197,39 @@ class Relationship(
self,
argument: Optional[_RelationshipArgumentType[_T]] = None,
secondary=None,
+ *,
+ uselist=None,
+ collection_class=None,
primaryjoin=None,
secondaryjoin=None,
- foreign_keys=None,
- uselist=None,
+ back_populates=None,
order_by=False,
backref=None,
- back_populates=None,
+ cascade_backrefs=False,
overlaps=None,
post_update=False,
- cascade=False,
+ cascade="save-update, merge",
viewonly=False,
- lazy="select",
- collection_class=None,
- passive_deletes=_persistence_only["passive_deletes"],
- passive_updates=_persistence_only["passive_updates"],
+ lazy: _LazyLoadArgumentType = "select",
+ passive_deletes=False,
+ passive_updates=True,
+ active_history=False,
+ enable_typechecks=True,
+ foreign_keys=None,
remote_side=None,
- enable_typechecks=_persistence_only["enable_typechecks"],
join_depth=None,
comparator_factory=None,
single_parent=False,
innerjoin=False,
distinct_target_key=None,
- doc=None,
- active_history=_persistence_only["active_history"],
- cascade_backrefs=_persistence_only["cascade_backrefs"],
load_on_pending=False,
- bake_queries=True,
- _local_remote_pairs=None,
query_class=None,
info=None,
omit_join=None,
sync_backref=None,
+ doc=None,
+ bake_queries=True,
+ _local_remote_pairs=None,
_legacy_inactive_history_style=False,
):
super(Relationship, self).__init__()
@@ -250,7 +289,6 @@ class Relationship(
self.omit_join = omit_join
self.local_remote_pairs = _local_remote_pairs
- self.bake_queries = bake_queries
self.load_on_pending = load_on_pending
self.comparator_factory = comparator_factory or Relationship.Comparator
self.comparator = self.comparator_factory(self, None)
@@ -267,12 +305,7 @@ class Relationship(
else:
self._overlaps = ()
- if cascade is not False:
- self.cascade = cascade
- elif self.viewonly:
- self.cascade = "none"
- else:
- self.cascade = "save-update, merge"
+ self.cascade = cascade
self.order_by = order_by
@@ -539,9 +572,9 @@ class Relationship(
def _criterion_exists(
self,
- criterion: Optional[SQLCoreOperations[Any]] = None,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
**kwargs: Any,
- ) -> Exists[bool]:
+ ) -> Exists:
if getattr(self, "_of_type", None):
info = inspect(self._of_type)
target_mapper, to_selectable, is_aliased_class = (
@@ -898,7 +931,12 @@ class Relationship(
comparator: Comparator[_T]
- def _with_parent(self, instance, alias_secondary=True, from_entity=None):
+ def _with_parent(
+ self,
+ instance: object,
+ alias_secondary: bool = True,
+ from_entity: Optional[_EntityType[Any]] = None,
+ ) -> ColumnElement[bool]:
assert instance is not None
adapt_source = None
if from_entity is not None:
@@ -1502,7 +1540,7 @@ class Relationship(
argument = argument
if isinstance(argument, type):
- entity = mapperlib.class_mapper(argument, configure=False)
+ entity = class_mapper(argument, configure=False)
else:
try:
entity = inspect(argument)
@@ -1568,7 +1606,7 @@ class Relationship(
"""Test that this relationship is legal, warn about
inheritance conflicts."""
mapperlib = util.preloaded.orm_mapper
- if self.parent.non_primary and not mapperlib.class_mapper(
+ if self.parent.non_primary and not class_mapper(
self.parent.class_, configure=False
).has_property(self.key):
raise sa_exc.ArgumentError(
@@ -1585,29 +1623,23 @@ class Relationship(
)
@property
- def cascade(self):
+ def cascade(self) -> CascadeOptions:
"""Return the current cascade setting for this
:class:`.Relationship`.
"""
return self._cascade
@cascade.setter
- def cascade(self, cascade):
+ def cascade(self, cascade: Union[str, CascadeOptions]):
self._set_cascade(cascade)
- def _set_cascade(self, cascade):
- cascade = CascadeOptions(cascade)
+ def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]):
+ cascade = CascadeOptions(cascade_arg)
if self.viewonly:
- non_viewonly = set(cascade).difference(
- CascadeOptions._viewonly_cascades
+ cascade = CascadeOptions(
+ cascade.intersection(CascadeOptions._viewonly_cascades)
)
- if non_viewonly:
- raise sa_exc.ArgumentError(
- 'Cascade settings "%s" apply to persistence operations '
- "and should not be combined with a viewonly=True "
- "relationship." % (", ".join(sorted(non_viewonly)))
- )
if "mapper" in self.__dict__:
self._check_cascade_settings(cascade)
@@ -1754,8 +1786,8 @@ class Relationship(
relationship = Relationship(
parent,
self.secondary,
- pj,
- sj,
+ primaryjoin=pj,
+ secondaryjoin=sj,
foreign_keys=foreign_keys,
back_populates=self.key,
**kwargs,
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 5b1d0bb08..74035ec0a 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -39,6 +39,7 @@ from . import persistence
from . import query
from . import state as statelib
from ._typing import _O
+from ._typing import insp_is_mapper
from ._typing import is_composite_class
from ._typing import is_user_defined_option
from .base import _class_to_mapper
@@ -69,12 +70,14 @@ from ..engine.util import TransactionalContext
from ..event import dispatcher
from ..event import EventTarget
from ..inspection import inspect
+from ..inspection import Inspectable
from ..sql import coercions
from ..sql import dml
from ..sql import roles
from ..sql import Select
from ..sql import visitors
from ..sql.base import CompileState
+from ..sql.schema import Table
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util import IdentitySet
@@ -90,6 +93,7 @@ if typing.TYPE_CHECKING:
from .path_registry import PathRegistry
from ..engine import Result
from ..engine import Row
+ from ..engine import RowMapping
from ..engine.base import Transaction
from ..engine.base import TwoPhaseTransaction
from ..engine.interfaces import _CoreAnyExecuteParams
@@ -103,6 +107,7 @@ if typing.TYPE_CHECKING:
from ..sql.base import Executable
from ..sql.elements import ClauseElement
from ..sql.schema import Table
+ from ..sql.selectable import TableClause
__all__ = [
"Session",
@@ -184,7 +189,7 @@ class _SessionClassMethods:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ row: Optional[Union[Row, RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
"""Return an identity key.
@@ -2050,9 +2055,12 @@ class Session(_SessionClassMethods, EventTarget):
else:
self.__binds[key] = bind
else:
- if insp.is_selectable:
+ if TYPE_CHECKING:
+ assert isinstance(insp, Inspectable)
+
+ if isinstance(insp, Table):
self.__binds[insp] = bind
- elif insp.is_mapper:
+ elif insp_is_mapper(insp):
self.__binds[insp.class_] = bind
for _selectable in insp._all_tables:
self.__binds[_selectable] = bind
@@ -2211,7 +2219,7 @@ class Session(_SessionClassMethods, EventTarget):
# we don't have self.bind and either have self.__binds
# or we don't have self.__binds (which is legacy). Look at the
# mapper and the clause
- if mapper is clause is None:
+ if mapper is None and clause is None:
if self.bind:
return self.bind
else:
@@ -2350,7 +2358,10 @@ class Session(_SessionClassMethods, EventTarget):
key = mapper.identity_key_from_primary_key(
primary_key_identity, identity_token=identity_token
)
- return loading.get_from_identity(self, mapper, key, passive)
+
+ # work around: https://github.com/python/typing/discussions/1143
+ return_value = loading.get_from_identity(self, mapper, key, passive)
+ return return_value
@util.non_memoized_property
@contextlib.contextmanager
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 85e015193..2d85ba7f6 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -2162,10 +2162,11 @@ class JoinedLoader(AbstractRelationshipLoader):
else:
to_adapt = self._gen_pooled_aliased_class(compile_state)
- clauses = inspect(to_adapt)._memo(
+ to_adapt_insp = inspect(to_adapt)
+ clauses = to_adapt_insp._memo(
("joinedloader_ormadapter", self),
orm_util.ORMAdapter,
- to_adapt,
+ to_adapt_insp,
equivalents=self.mapper._equivalent_columns,
adapt_required=True,
allow_label_resolve=False,
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 4699781a4..3934de535 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -11,8 +11,16 @@ import re
import types
import typing
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Match
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TypeVar
@@ -20,32 +28,35 @@ from typing import Union
import weakref
from . import attributes # noqa
-from .base import _class_to_mapper # noqa
-from .base import _never_set # noqa
-from .base import _none_set # noqa
-from .base import attribute_str # noqa
-from .base import class_mapper # noqa
-from .base import InspectionAttr # noqa
-from .base import instance_str # noqa
-from .base import object_mapper # noqa
-from .base import object_state # noqa
-from .base import state_attribute_str # noqa
-from .base import state_class_str # noqa
-from .base import state_str # noqa
+from ._typing import _O
+from ._typing import insp_is_aliased_class
+from ._typing import insp_is_mapper
+from ._typing import prop_is_relationship
+from .base import _class_to_mapper as _class_to_mapper
+from .base import _never_set as _never_set
+from .base import _none_set as _none_set
+from .base import attribute_str as attribute_str
+from .base import class_mapper as class_mapper
+from .base import InspectionAttr as InspectionAttr
+from .base import instance_str as instance_str
+from .base import object_mapper as object_mapper
+from .base import object_state as object_state
+from .base import state_attribute_str as state_attribute_str
+from .base import state_class_str as state_class_str
+from .base import state_str as state_str
from .interfaces import CriteriaOption
-from .interfaces import MapperProperty # noqa
+from .interfaces import MapperProperty as MapperProperty
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
from .interfaces import ORMFromClauseRole
-from .interfaces import PropComparator # noqa
-from .path_registry import PathRegistry # noqa
+from .interfaces import PropComparator as PropComparator
+from .path_registry import PathRegistry as PathRegistry
from .. import event
from .. import exc as sa_exc
from .. import inspection
from .. import sql
from .. import util
from ..engine.result import result_tuple
-from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
from ..sql import lambdas
@@ -54,19 +65,39 @@ from ..sql import util as sql_util
from ..sql import visitors
from ..sql.annotation import SupportsCloneAnnotations
from ..sql.base import ColumnCollection
+from ..sql.cache_key import HasCacheKey
+from ..sql.cache_key import MemoizedHasCacheKey
+from ..sql.elements import ColumnElement
from ..sql.selectable import FromClause
from ..util.langhelpers import MemoizedSlots
from ..util.typing import de_stringify_annotation
from ..util.typing import is_origin_of
+from ..util.typing import Literal
if typing.TYPE_CHECKING:
from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InternalEntityType
+ from ._typing import _ORMColumnExprArgument
+ from .context import _MapperEntity
+ from .context import ORMCompileState
from .mapper import Mapper
+ from .relationships import Relationship
from ..engine import Row
+ from ..engine import RowMapping
+ from ..sql._typing import _ColumnExpressionArgument
+ from ..sql._typing import _EquivalentColumnMap
+ from ..sql._typing import _FromClauseArgument
+ from ..sql._typing import _OnClauseArgument
from ..sql._typing import _PropagateAttrsType
+ from ..sql.base import ReadOnlyColumnCollection
+ from ..sql.elements import BindParameter
+ from ..sql.selectable import _ColumnsClauseElement
from ..sql.selectable import Alias
+ from ..sql.selectable import Subquery
+ from ..sql.visitors import _ET
+ from ..sql.visitors import anon_map
+ from ..sql.visitors import ExternallyTraversible
_T = TypeVar("_T", bound=Any)
@@ -84,7 +115,7 @@ all_cascades = frozenset(
)
-class CascadeOptions(frozenset):
+class CascadeOptions(FrozenSet[str]):
"""Keeps track of the options sent to
:paramref:`.relationship.cascade`"""
@@ -104,6 +135,13 @@ class CascadeOptions(frozenset):
"delete_orphan",
)
+ save_update: bool
+ delete: bool
+ refresh_expire: bool
+ merge: bool
+ expunge: bool
+ delete_orphan: bool
+
def __new__(cls, value_list):
if isinstance(value_list, str) or value_list is None:
return cls.from_string(value_list)
@@ -127,7 +165,7 @@ class CascadeOptions(frozenset):
values.clear()
values.discard("all")
- self = frozenset.__new__(CascadeOptions, values)
+ self = super().__new__(cls, values) # type: ignore
self.save_update = "save-update" in values
self.delete = "delete" in values
self.refresh_expire = "refresh-expire" in values
@@ -238,7 +276,7 @@ def polymorphic_union(
"""
- colnames = util.OrderedSet()
+ colnames: util.OrderedSet[str] = util.OrderedSet()
colnamemaps = {}
types = {}
for key in table_map:
@@ -299,13 +337,13 @@ def polymorphic_union(
def identity_key(
- class_: Optional[Type[Any]] = None,
+ class_: Optional[Type[_T]] = None,
ident: Union[Any, Tuple[Any, ...]] = None,
*,
- instance: Optional[Any] = None,
- row: Optional[Row] = None,
+ instance: Optional[_T] = None,
+ row: Optional[Union[Row, RowMapping]] = None,
identity_token: Optional[Any] = None,
-) -> _IdentityKeyType:
+) -> _IdentityKeyType[_T]:
r"""Generate "identity key" tuples, as are used as keys in the
:attr:`.Session.identity_map` dictionary.
@@ -351,7 +389,7 @@ def identity_key(
* ``identity_key(class, row=row, identity_token=token)``
This form is similar to the class/tuple form, except is passed a
- database result row as a :class:`.Row` object.
+ database result row as a :class:`.Row` or :class:`.RowMapping` object.
E.g.::
@@ -375,7 +413,7 @@ def identity_key(
if ident is None:
raise sa_exc.ArgumentError("ident or row is required")
return mapper.identity_key_from_primary_key(
- util.to_list(ident), identity_token=identity_token
+ tuple(util.to_list(ident)), identity_token=identity_token
)
else:
return mapper.identity_key_from_row(
@@ -394,24 +432,26 @@ class ORMAdapter(sql_util.ColumnAdapter):
"""
- is_aliased_class = False
- aliased_insp = None
+ is_aliased_class: bool
+ aliased_insp: Optional[AliasedInsp[Any]]
def __init__(
self,
- entity,
- equivalents=None,
- adapt_required=False,
- allow_label_resolve=True,
- anonymize_labels=False,
+ entity: _InternalEntityType[Any],
+ equivalents: Optional[_EquivalentColumnMap] = None,
+ adapt_required: bool = False,
+ allow_label_resolve: bool = True,
+ anonymize_labels: bool = False,
):
- info = inspection.inspect(entity)
- self.mapper = info.mapper
- selectable = info.selectable
- if info.is_aliased_class:
+ self.mapper = entity.mapper
+ selectable = entity.selectable
+ if insp_is_aliased_class(entity):
self.is_aliased_class = True
- self.aliased_insp = info
+ self.aliased_insp = entity
+ else:
+ self.is_aliased_class = False
+ self.aliased_insp = None
sql_util.ColumnAdapter.__init__(
self,
@@ -428,7 +468,7 @@ class ORMAdapter(sql_util.ColumnAdapter):
return not entity or entity.isa(self.mapper)
-class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
+class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]):
r"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
@@ -489,19 +529,20 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
def __init__(
self,
- mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"],
- alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
- # TODO: None for default here?
- with_polymorphic_mappers=(),
- with_polymorphic_discriminator=None,
- base_alias=None,
- use_mapper_path=False,
- represents_outer_join=False,
+ mapped_class_or_ac: _EntityType[_O],
+ alias: Optional[FromClause] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+ with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None,
+ with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None,
+ base_alias: Optional[AliasedInsp[Any]] = None,
+ use_mapper_path: bool = False,
+ represents_outer_join: bool = False,
):
- insp = inspection.inspect(mapped_class_or_ac)
+ insp = cast(
+ "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac)
+ )
mapper = insp.mapper
nest_adapters = False
@@ -519,6 +560,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
elif insp.is_aliased_class:
nest_adapters = True
+ assert alias is not None
self._aliased_insp = AliasedInsp(
self,
insp,
@@ -540,7 +582,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
self.__name__ = f"aliased({mapper.class_.__name__})"
@classmethod
- def _reconstitute_from_aliased_insp(cls, aliased_insp):
+ def _reconstitute_from_aliased_insp(
+ cls, aliased_insp: AliasedInsp[_O]
+ ) -> AliasedClass[_O]:
obj = cls.__new__(cls)
obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})"
obj._aliased_insp = aliased_insp
@@ -555,7 +599,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
return obj
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> Any:
try:
_aliased_insp = self.__dict__["_aliased_insp"]
except KeyError:
@@ -584,7 +628,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
return attr
- def _get_from_serialized(self, key, mapped_class, aliased_insp):
+ def _get_from_serialized(
+ self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O]
+ ) -> Any:
# this method is only used in terms of the
# sqlalchemy.ext.serializer extension
attr = getattr(mapped_class, key)
@@ -605,23 +651,25 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]):
return attr
- def __repr__(self):
+ def __repr__(self) -> str:
return "<AliasedClass at 0x%x; %s>" % (
id(self),
self._aliased_insp._target.__name__,
)
- def __str__(self):
+ def __str__(self) -> str:
return str(self._aliased_insp)
+@inspection._self_inspects
class AliasedInsp(
ORMEntityColumnsClauseRole,
ORMFromClauseRole,
- sql_base.HasCacheKey,
+ HasCacheKey,
InspectionAttr,
MemoizedSlots,
- Generic[_T],
+ inspection.Inspectable["AliasedInsp[_O]"],
+ Generic[_O],
):
"""Provide an inspection interface for an
:class:`.AliasedClass` object.
@@ -685,19 +733,36 @@ class AliasedInsp(
"_nest_adapters",
)
+ mapper: Mapper[_O]
+ selectable: FromClause
+ _adapter: sql_util.ColumnAdapter
+ with_polymorphic_mappers: Sequence[Mapper[Any]]
+ _with_polymorphic_entities: Sequence[AliasedInsp[Any]]
+
+ _weak_entity: weakref.ref[AliasedClass[_O]]
+ """the AliasedClass that refers to this AliasedInsp"""
+
+ _target: Union[_O, AliasedClass[_O]]
+ """the thing referred towards by the AliasedClass/AliasedInsp.
+
+ In the vast majority of cases, this is the mapped class. However
+ it may also be another AliasedClass (alias of alias).
+
+ """
+
def __init__(
self,
- entity: _EntityType,
- inspected: _InternalEntityType,
- selectable,
- name,
- with_polymorphic_mappers,
- polymorphic_on,
- _base_alias,
- _use_mapper_path,
- adapt_on_names,
- represents_outer_join,
- nest_adapters,
+ entity: AliasedClass[_O],
+ inspected: _InternalEntityType[_O],
+ selectable: FromClause,
+ name: Optional[str],
+ with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]],
+ polymorphic_on: Optional[ColumnElement[Any]],
+ _base_alias: Optional[AliasedInsp[Any]],
+ _use_mapper_path: bool,
+ adapt_on_names: bool,
+ represents_outer_join: bool,
+ nest_adapters: bool,
):
mapped_class_or_ac = inspected.entity
@@ -752,23 +817,22 @@ class AliasedInsp(
)
if nest_adapters:
+ # supports "aliased class of aliased class" use case
+ assert isinstance(inspected, AliasedInsp)
self._adapter = inspected._adapter.wrap(self._adapter)
self._adapt_on_names = adapt_on_names
self._target = mapped_class_or_ac
- # self._target = mapper.class_ # mapped_class_or_ac
@classmethod
def _alias_factory(
cls,
- element: Union[
- Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"
- ],
- alias=None,
- name=None,
- flat=False,
- adapt_on_names=False,
- ) -> Union["AliasedClass[_T]", "Alias"]:
+ element: Union[_EntityType[_O], FromClause],
+ alias: Optional[Union[Alias, Subquery]] = None,
+ name: Optional[str] = None,
+ flat: bool = False,
+ adapt_on_names: bool = False,
+ ) -> Union[AliasedClass[_O], FromClause]:
if isinstance(element, FromClause):
if adapt_on_names:
@@ -793,16 +857,16 @@ class AliasedInsp(
@classmethod
def _with_polymorphic_factory(
cls,
- base,
- classes,
- selectable=False,
- flat=False,
- polymorphic_on=None,
- aliased=False,
- innerjoin=False,
- adapt_on_names=False,
- _use_mapper_path=False,
- ):
+ base: Union[_O, Mapper[_O]],
+ classes: Iterable[Type[Any]],
+ selectable: Union[Literal[False, None], FromClause] = False,
+ flat: bool = False,
+ polymorphic_on: Optional[ColumnElement[Any]] = None,
+ aliased: bool = False,
+ innerjoin: bool = False,
+ adapt_on_names: bool = False,
+ _use_mapper_path: bool = False,
+ ) -> AliasedClass[_O]:
primary_mapper = _class_to_mapper(base)
@@ -816,7 +880,9 @@ class AliasedInsp(
classes, selectable, innerjoin=innerjoin
)
if aliased or flat:
+ assert selectable is not None
selectable = selectable._anonymous_fromclause(flat=flat)
+
return AliasedClass(
base,
selectable,
@@ -828,7 +894,7 @@ class AliasedInsp(
)
@property
- def entity(self):
+ def entity(self) -> AliasedClass[_O]:
# to eliminate reference cycles, the AliasedClass is held weakly.
# this produces some situations where the AliasedClass gets lost,
# particularly when one is created internally and only the AliasedInsp
@@ -844,7 +910,7 @@ class AliasedInsp(
is_aliased_class = True
"always returns True"
- def _memoized_method___clause_element__(self):
+ def _memoized_method___clause_element__(self) -> FromClause:
return self.selectable._annotate(
{
"parentmapper": self.mapper,
@@ -856,7 +922,7 @@ class AliasedInsp(
)
@property
- def entity_namespace(self):
+ def entity_namespace(self) -> AliasedClass[_O]:
return self.entity
_cache_key_traversal = [
@@ -866,7 +932,7 @@ class AliasedInsp(
]
@property
- def class_(self):
+ def class_(self) -> Type[_O]:
"""Return the mapped class ultimately represented by this
:class:`.AliasedInsp`."""
return self.mapper.class_
@@ -878,7 +944,7 @@ class AliasedInsp(
else:
return PathRegistry.per_mapper(self)
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {
"entity": self.entity,
"mapper": self.mapper,
@@ -893,8 +959,8 @@ class AliasedInsp(
"nest_adapters": self._nest_adapters,
}
- def __setstate__(self, state):
- self.__init__(
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ self.__init__( # type: ignore
state["entity"],
state["mapper"],
state["alias"],
@@ -908,7 +974,7 @@ class AliasedInsp(
state["nest_adapters"],
)
- def _merge_with(self, other):
+ def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]:
# assert self._is_with_polymorphic
# assert other._is_with_polymorphic
@@ -929,7 +995,6 @@ class AliasedInsp(
classes, None, innerjoin=not other.represents_outer_join
)
selectable = selectable._anonymous_fromclause(flat=True)
-
return AliasedClass(
primary_mapper,
selectable,
@@ -937,10 +1002,13 @@ class AliasedInsp(
with_polymorphic_discriminator=other.polymorphic_on,
use_mapper_path=other._use_mapper_path,
represents_outer_join=other.represents_outer_join,
- )
+ )._aliased_insp
- def _adapt_element(self, elem, key=None):
- d = {
+ def _adapt_element(
+ self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None
+ ) -> _ORMColumnExprArgument[_T]:
+ assert isinstance(elem, ColumnElement)
+ d: Dict[str, Any] = {
"parententity": self,
"parentmapper": self.mapper,
}
@@ -1084,35 +1152,45 @@ class LoaderCriteriaOption(CriteriaOption):
("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
]
+ root_entity: Optional[Type[Any]]
+ entity: Optional[_InternalEntityType[Any]]
+ where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement]
+ deferred_where_criteria: bool
+ include_aliases: bool
+ propagate_to_loaders: bool
+
def __init__(
self,
- entity_or_base,
- where_criteria,
- loader_only=False,
- include_aliases=False,
- propagate_to_loaders=True,
- track_closure_variables=True,
+ entity_or_base: _EntityType[Any],
+ where_criteria: _ColumnExpressionArgument[bool],
+ loader_only: bool = False,
+ include_aliases: bool = False,
+ propagate_to_loaders: bool = True,
+ track_closure_variables: bool = True,
):
- entity = inspection.inspect(entity_or_base, False)
+ entity = cast(
+ "_InternalEntityType[Any]",
+ inspection.inspect(entity_or_base, False),
+ )
if entity is None:
- self.root_entity = entity_or_base
+ self.root_entity = cast("Type[Any]", entity_or_base)
self.entity = None
else:
self.root_entity = None
self.entity = entity
if callable(where_criteria):
+ if self.root_entity is not None:
+ wrap_entity = self.root_entity
+ else:
+ assert entity is not None
+ wrap_entity = entity.entity
+
self.deferred_where_criteria = True
self.where_criteria = lambdas.DeferredLambdaElement(
- where_criteria,
+ where_criteria, # type: ignore
roles.WhereHavingRole,
- lambda_args=(
- _WrapUserEntity(
- self.root_entity
- if self.root_entity is not None
- else self.entity.entity,
- ),
- ),
+ lambda_args=(_WrapUserEntity(wrap_entity),),
opts=lambdas.LambdaOptions(
track_closure_variables=track_closure_variables
),
@@ -1126,22 +1204,27 @@ class LoaderCriteriaOption(CriteriaOption):
self.include_aliases = include_aliases
self.propagate_to_loaders = propagate_to_loaders
- def _all_mappers(self):
+ def _all_mappers(self) -> Iterator[Mapper[Any]]:
+
if self.entity:
- for ent in self.entity.mapper.self_and_descendants:
- yield ent
+ for mp_ent in self.entity.mapper.self_and_descendants:
+ yield mp_ent
else:
+ assert self.root_entity
stack = list(self.root_entity.__subclasses__())
while stack:
subclass = stack.pop(0)
- ent = inspection.inspect(subclass, raiseerr=False)
+ ent = cast(
+ "_InternalEntityType[Any]",
+ inspection.inspect(subclass, raiseerr=False),
+ )
if ent:
for mp in ent.mapper.self_and_descendants:
yield mp
else:
stack.extend(subclass.__subclasses__())
- def _should_include(self, compile_state):
+ def _should_include(self, compile_state: ORMCompileState) -> bool:
if (
compile_state.select_statement._annotations.get(
"for_loader_criteria", None
@@ -1151,21 +1234,29 @@ class LoaderCriteriaOption(CriteriaOption):
return False
return True
- def _resolve_where_criteria(self, ext_info):
+ def _resolve_where_criteria(
+ self, ext_info: _InternalEntityType[Any]
+ ) -> ColumnElement[bool]:
if self.deferred_where_criteria:
- crit = self.where_criteria._resolve_with_args(ext_info.entity)
+ crit = cast(
+ "ColumnElement[bool]",
+ self.where_criteria._resolve_with_args(ext_info.entity),
+ )
else:
- crit = self.where_criteria
+ crit = self.where_criteria # type: ignore
+ assert isinstance(crit, ColumnElement)
return sql_util._deep_annotate(
crit, {"for_loader_criteria": self}, detect_subquery_cols=True
)
def process_compile_state_replaced_entities(
- self, compile_state, mapper_entities
- ):
- return self.process_compile_state(compile_state)
+ self,
+ compile_state: ORMCompileState,
+ mapper_entities: Iterable[_MapperEntity],
+ ) -> None:
+ self.process_compile_state(compile_state)
- def process_compile_state(self, compile_state):
+ def process_compile_state(self, compile_state: ORMCompileState) -> None:
"""Apply a modification to a given :class:`.CompileState`."""
# if options to limit the criteria to immediate query only,
@@ -1173,7 +1264,7 @@ class LoaderCriteriaOption(CriteriaOption):
self.get_global_criteria(compile_state.global_attributes)
- def get_global_criteria(self, attributes):
+ def get_global_criteria(self, attributes: Dict[Any, Any]) -> None:
for mp in self._all_mappers():
load_criteria = attributes.setdefault(
("additional_entity_criteria", mp), []
@@ -1183,14 +1274,14 @@ class LoaderCriteriaOption(CriteriaOption):
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
-inspection._inspects(AliasedInsp)(lambda target: target)
@inspection._self_inspects
class Bundle(
ORMColumnsClauseRole,
SupportsCloneAnnotations,
- sql_base.MemoizedHasCacheKey,
+ MemoizedHasCacheKey,
+ inspection.Inspectable["Bundle"],
InspectionAttr,
):
"""A grouping of SQL expressions that are returned by a :class:`.Query`
@@ -1227,7 +1318,11 @@ class Bundle(
_propagate_attrs: _PropagateAttrsType = util.immutabledict()
- def __init__(self, name, *exprs, **kw):
+ exprs: List[_ColumnsClauseElement]
+
+ def __init__(
+ self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any
+ ):
r"""Construct a new :class:`.Bundle`.
e.g.::
@@ -1246,37 +1341,43 @@ class Bundle(
"""
self.name = self._label = name
- self.exprs = exprs = [
+ coerced_exprs = [
coercions.expect(
roles.ColumnsClauseRole, expr, apply_propagate_attrs=self
)
for expr in exprs
]
+ self.exprs = coerced_exprs
self.c = self.columns = ColumnCollection(
(getattr(col, "key", col._label), col)
- for col in [e._annotations.get("bundle", e) for e in exprs]
- )
+ for col in [e._annotations.get("bundle", e) for e in coerced_exprs]
+ ).as_readonly()
self.single_entity = kw.pop("single_entity", self.single_entity)
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(
+ self, anon_map: anon_map, bindparams: List[BindParameter[Any]]
+ ) -> Tuple[Any, ...]:
return (self.__class__, self.name, self.single_entity) + tuple(
[expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs]
)
@property
- def mapper(self):
+ def mapper(self) -> Mapper[Any]:
return self.exprs[0]._annotations.get("parentmapper", None)
@property
- def entity(self):
+ def entity(self) -> _InternalEntityType[Any]:
return self.exprs[0]._annotations.get("parententity", None)
@property
- def entity_namespace(self):
+ def entity_namespace(
+ self,
+ ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]:
return self.c
- columns = None
+ columns: ReadOnlyColumnCollection[str, ColumnElement[Any]]
+
"""A namespace of SQL expressions referred to by this :class:`.Bundle`.
e.g.::
@@ -1301,7 +1402,7 @@ class Bundle(
"""
- c = None
+ c: ReadOnlyColumnCollection[str, ColumnElement[Any]]
"""An alias for :attr:`.Bundle.columns`."""
def _clone(self):
@@ -1400,32 +1501,30 @@ class _ORMJoin(expression.Join):
def __init__(
self,
- left,
- right,
- onclause=None,
- isouter=False,
- full=False,
- _left_memo=None,
- _right_memo=None,
- _extra_criteria=(),
+ left: _FromClauseArgument,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+ _left_memo: Optional[Any] = None,
+ _right_memo: Optional[Any] = None,
+ _extra_criteria: Sequence[ColumnElement[bool]] = (),
):
- left_info = inspection.inspect(left)
+ left_info = cast(
+ "Union[FromClause, _InternalEntityType[Any]]",
+ inspection.inspect(left),
+ )
- right_info = inspection.inspect(right)
+ right_info = cast(
+ "Union[FromClause, _InternalEntityType[Any]]",
+ inspection.inspect(right),
+ )
adapt_to = right_info.selectable
# used by joined eager loader
self._left_memo = _left_memo
self._right_memo = _right_memo
- # legacy, for string attr name ON clause. if that's removed
- # then the "_joined_from_info" concept can go
- left_orm_info = getattr(left, "_joined_from_info", left_info)
- self._joined_from_info = right_info
- if isinstance(onclause, str):
- onclause = getattr(left_orm_info.entity, onclause)
- # ####
-
if isinstance(onclause, attributes.QueryableAttribute):
on_selectable = onclause.comparator._source_selectable()
prop = onclause.property
@@ -1477,20 +1576,23 @@ class _ORMJoin(expression.Join):
augment_onclause = onclause is None and _extra_criteria
expression.Join.__init__(self, left, right, onclause, isouter, full)
+ assert self.onclause is not None
+
if augment_onclause:
self.onclause &= sql.and_(*_extra_criteria)
if (
not prop
and getattr(right_info, "mapper", None)
- and right_info.mapper.single
+ and right_info.mapper.single # type: ignore
):
+ right_info = cast("_InternalEntityType[Any]", right_info)
# if single inheritance target and we are using a manual
# or implicit ON clause, augment it the same way we'd augment the
# WHERE.
single_crit = right_info.mapper._single_table_criterion
if single_crit is not None:
- if right_info.is_aliased_class:
+ if insp_is_aliased_class(right_info):
single_crit = right_info._adapter.traverse(single_crit)
self.onclause = self.onclause & single_crit
@@ -1525,19 +1627,27 @@ class _ORMJoin(expression.Join):
def join(
self,
- right,
- onclause=None,
- isouter=False,
- full=False,
- join_to_left=None,
- ):
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ isouter: bool = False,
+ full: bool = False,
+ ) -> _ORMJoin:
return _ORMJoin(self, right, onclause, full=full, isouter=isouter)
- def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
+ def outerjoin(
+ self,
+ right: _FromClauseArgument,
+ onclause: Optional[_OnClauseArgument] = None,
+ full: bool = False,
+ ) -> _ORMJoin:
return _ORMJoin(self, right, onclause, isouter=True, full=full)
-def with_parent(instance, prop, from_entity=None):
+def with_parent(
+ instance: object,
+ prop: attributes.QueryableAttribute[Any],
+ from_entity: Optional[_EntityType[Any]] = None,
+) -> ColumnElement[bool]:
"""Create filtering criterion that relates this query's primary entity
to the given related instance, using established
:func:`_orm.relationship()`
@@ -1588,6 +1698,8 @@ def with_parent(instance, prop, from_entity=None):
.. versionadded:: 1.2
"""
+ prop_t: Relationship[Any]
+
if isinstance(prop, str):
raise sa_exc.ArgumentError(
"with_parent() accepts class-bound mapped attributes, not strings"
@@ -1595,12 +1707,19 @@ def with_parent(instance, prop, from_entity=None):
elif isinstance(prop, attributes.QueryableAttribute):
if prop._of_type:
from_entity = prop._of_type
- prop = prop.property
+ if not prop_is_relationship(prop.property):
+ raise sa_exc.ArgumentError(
+ f"Expected relationship property for with_parent(), "
+ f"got {prop.property}"
+ )
+ prop_t = prop.property
+ else:
+ prop_t = prop
- return prop._with_parent(instance, from_entity=from_entity)
+ return prop_t._with_parent(instance, from_entity=from_entity)
-def has_identity(object_):
+def has_identity(object_: object) -> bool:
"""Return True if the given object has a database
identity.
@@ -1616,7 +1735,7 @@ def has_identity(object_):
return state.has_identity
-def was_deleted(object_):
+def was_deleted(object_: object) -> bool:
"""Return True if the given object was deleted
within a session flush.
@@ -1633,27 +1752,32 @@ def was_deleted(object_):
return state.was_deleted
-def _entity_corresponds_to(given, entity):
+def _entity_corresponds_to(
+ given: _InternalEntityType[Any], entity: _InternalEntityType[Any]
+) -> bool:
"""determine if 'given' corresponds to 'entity', in terms
of an entity passed to Query that would match the same entity
being referred to elsewhere in the query.
"""
- if entity.is_aliased_class:
- if given.is_aliased_class:
+ if insp_is_aliased_class(entity):
+ if insp_is_aliased_class(given):
if entity._base_alias() is given._base_alias():
return True
return False
- elif given.is_aliased_class:
+ elif insp_is_aliased_class(given):
if given._use_mapper_path:
return entity in given.with_polymorphic_mappers
else:
return entity is given
+ assert insp_is_mapper(given)
return entity.common_parent(given)
-def _entity_corresponds_to_use_path_impl(given, entity):
+def _entity_corresponds_to_use_path_impl(
+ given: _InternalEntityType[Any], entity: _InternalEntityType[Any]
+) -> bool:
"""determine if 'given' corresponds to 'entity', in terms
of a path of loader options where a mapped attribute is taken to
be a member of a parent entity.
@@ -1673,13 +1797,13 @@ def _entity_corresponds_to_use_path_impl(given, entity):
"""
- if given.is_aliased_class:
+ if insp_is_aliased_class(given):
return (
- entity.is_aliased_class
+ insp_is_aliased_class(entity)
and not entity._use_mapper_path
and (given is entity or entity in given._with_polymorphic_entities)
)
- elif not entity.is_aliased_class:
+ elif not insp_is_aliased_class(entity):
return given.isa(entity.mapper)
else:
return (
@@ -1688,7 +1812,7 @@ def _entity_corresponds_to_use_path_impl(given, entity):
)
-def _entity_isa(given, mapper):
+def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool:
"""determine if 'given' "is a" mapper, in terms of the given
would load rows of type 'mapper'.
@@ -1703,42 +1827,6 @@ def _entity_isa(given, mapper):
return given.isa(mapper)
-def randomize_unitofwork():
- """Use random-ordering sets within the unit of work in order
- to detect unit of work sorting issues.
-
- This is a utility function that can be used to help reproduce
- inconsistent unit of work sorting issues. For example,
- if two kinds of objects A and B are being inserted, and
- B has a foreign key reference to A - the A must be inserted first.
- However, if there is no relationship between A and B, the unit of work
- won't know to perform this sorting, and an operation may or may not
- fail, depending on how the ordering works out. Since Python sets
- and dictionaries have non-deterministic ordering, such an issue may
- occur on some runs and not on others, and in practice it tends to
- have a great dependence on the state of the interpreter. This leads
- to so-called "heisenbugs" where changing entirely irrelevant aspects
- of the test program still cause the failure behavior to change.
-
- By calling ``randomize_unitofwork()`` when a script first runs, the
- ordering of a key series of sets within the unit of work implementation
- are randomized, so that the script can be minimized down to the
- fundamental mapping and operation that's failing, while still reproducing
- the issue on at least some runs.
-
- This utility is also available when running the test suite via the
- ``--reversetop`` flag.
-
- """
- from sqlalchemy.orm import unitofwork, session, mapper, dependency
- from sqlalchemy.util import topological
- from sqlalchemy.testing.util import RandomSet
-
- topological.set = (
- unitofwork.set
- ) = session.set = mapper.set = dependency.set = RandomSet
-
-
def _getitem(iterable_query, item):
"""calculate __getitem__ in terms of an iterable query object
that also has a slice() method.
@@ -1780,16 +1868,21 @@ def _getitem(iterable_query, item):
return list(iterable_query[item : item + 1])[0]
-def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type):
+def _is_mapped_annotation(
+ raw_annotation: Union[type, str], cls: Type[Any]
+) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
-def _cleanup_mapped_str_annotation(annotation):
+def _cleanup_mapped_str_annotation(annotation: str) -> str:
# fix up an annotation that comes in as the form:
# 'Mapped[List[Address]]' so that it instead looks like:
# 'Mapped[List["Address"]]' , which will allow us to get
# "Address" as a string
+
+ inner: Optional[Match[str]]
+
mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
if mm and mm.group(1) == "Mapped":
stack = []
@@ -1839,8 +1932,8 @@ def _extract_mapped_subtype(
else:
if (
not hasattr(annotated, "__origin__")
- or not issubclass(annotated.__origin__, attr_cls)
- and not issubclass(attr_cls, annotated.__origin__)
+ or not issubclass(annotated.__origin__, attr_cls) # type: ignore
+ and not issubclass(attr_cls, annotated.__origin__) # type: ignore
):
our_annotated_str = (
annotated.__name__
@@ -1853,9 +1946,9 @@ def _extract_mapped_subtype(
f'"{attr_cls.__name__}[{our_annotated_str}]".'
)
- if len(annotated.__args__) != 1:
+ if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(
"Expected sub-type for Mapped[] annotation"
)
- return annotated.__args__[0]
+ return annotated.__args__[0] # type: ignore