diff options
Diffstat (limited to 'lib/sqlalchemy/orm')
| -rw-r--r-- | lib/sqlalchemy/orm/_orm_constructors.py | 558 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/_typing.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/base.py | 77 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/context.py | 44 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/events.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/exc.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/instrumentation.py | 81 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 302 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 649 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/path_registry.py | 445 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 67 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 114 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 21 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 539 |
21 files changed, 1831 insertions, 1195 deletions
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 7690c05de..457ad5c5a 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -9,20 +9,19 @@ from __future__ import annotations import typing from typing import Any +from typing import Callable from typing import Collection -from typing import Dict -from typing import List from typing import Optional from typing import overload -from typing import Set from typing import Type +from typing import TYPE_CHECKING from typing import Union -from . import mapper as mapperlib +from . import mapperlib as mapperlib +from ._typing import _O from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym -from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,11 +36,29 @@ from .. import sql from .. import util from ..exc import InvalidRequestError from ..sql.base import SchemaEventTarget -from ..sql.selectable import Alias +from ..sql.schema import SchemaConst from ..sql.selectable import FromClause -from ..sql.type_api import TypeEngine from ..util.typing import Literal +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CompositeAttrType + from .interfaces import PropComparator + from .query import Query + from .relationships import _LazyLoadArgumentType + from .relationships import _ORMBackrefArgument + from .relationships import _ORMColCollectionArgument + from .relationships import _ORMOrderByArgument + from .relationships import _RelationshipJoinConditionArgument + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + from ..sql._typing import _TypeEngineArgument + from ..sql.schema import _ServerDefaultType + from ..sql.schema import FetchedValue + from ..sql.selectable import Alias + from ..sql.selectable import Subquery + _T = typing.TypeVar("_T") @@ -61,7 +78,7 @@ SynonymProperty = Synonym "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) -def contains_alias(alias) -> AliasOption: +def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. @@ -70,134 +87,36 @@ def contains_alias(alias) -> AliasOption: return AliasOption(alias) -# see test/ext/mypy/plain_files/mapped_column.py for mapped column -# typing tests - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[None] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[True] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Optional[_T]]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: Literal[False] = ..., - primary_key: Literal[None] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __name: str, - __type: Union[Type[TypeEngine[_T]], TypeEngine[_T]], - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: Literal[True] = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[_T]": - ... - - -@overload -def mapped_column( - __name: str, - *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -@overload def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[Any], SchemaEventTarget] + ] = None, *args: SchemaEventTarget, - nullable: bool = ..., - primary_key: bool = ..., - deferred: bool = ..., - **kw: Any, -) -> "MappedColumn[Any]": - ... - - -def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + **dialect_kwargs: Any, +) -> MappedColumn[Any]: r"""construct a new ORM-mapped :class:`_schema.Column` construct. The :func:`_orm.mapped_column` function provides an ORM-aware and @@ -363,12 +282,45 @@ def mapped_column(*args: Any, **kw: Any) -> "MappedColumn[Any]": """ - return MappedColumn(*args, **kw) + return MappedColumn( + __name_pos, + __type_pos, + *args, + name=name, + type_=type_, + autoincrement=autoincrement, + default=default, + doc=doc, + key=key, + index=index, + unique=unique, + info=info, + nullable=nullable, + onupdate=onupdate, + primary_key=primary_key, + server_default=server_default, + server_onupdate=server_onupdate, + quote=quote, + comment=comment, + system=system, + deferred=deferred, + **dialect_kwargs, + ) def column_property( - column: sql.ColumnElement[_T], *additional_columns, **kwargs -) -> "ColumnProperty[_T]": + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: r"""Provide a column-level property for use with a mapping. Column-based properties can normally be applied to the mapper's @@ -452,13 +404,25 @@ def column_property( expressions """ - return ColumnProperty(column, *additional_columns, **kwargs) + return ColumnProperty( + column, + *additional_columns, + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + descriptor=descriptor, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) @overload def composite( class_: Type[_T], - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[_T]: ... @@ -466,7 +430,7 @@ def composite( @overload def composite( - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: ... @@ -474,7 +438,7 @@ def composite( def composite( class_: Any = None, - *attrs: Union[sql.ColumnElement[Any], MappedColumn, str, Mapped[Any]], + *attrs: _CompositeAttrType[Any], **kwargs: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -529,13 +493,13 @@ def composite( def with_loader_criteria( - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, -) -> "LoaderCriteriaOption": + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, +) -> LoaderCriteriaOption: """Add additional WHERE criteria to the load for all occurrences of a particular entity. @@ -711,180 +675,40 @@ def with_loader_criteria( ) -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[Any]]: - ... - - -@overload -def relationship( - argument: str, - secondary=..., - *, - uselist: bool = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[Any]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[False] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[_T]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[List] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[List[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Set] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Set[_T]]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]], - secondary=..., - *, - uselist: Union[Literal[None], Literal[True]] = ..., - collection_class: Type[Dict[Any, Any]] = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Dict[Any, _T]]: - ... - - -@overload -def relationship( - argument: _RelationshipArgumentType[_T], - secondary=..., - *, - uselist: Literal[None] = ..., - collection_class: Literal[None] = ..., - primaryjoin=..., - secondaryjoin=None, - back_populates=None, - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = ..., - secondary=..., - *, - uselist: Literal[True] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -@overload def relationship( - argument: Literal[None] = ..., - secondary=..., - *, - uselist: Optional[bool] = ..., - collection_class: Any = ..., - primaryjoin=..., - secondaryjoin=..., - back_populates=..., - **kw: Any, -) -> Relationship[Any]: - ... - - -def relationship( - argument: Optional[_RelationshipArgumentType[_T]] = None, - secondary=None, + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[FromClause] = None, *, uselist: Optional[bool] = None, - collection_class: Optional[Type[Collection]] = None, - primaryjoin=None, - secondaryjoin=None, - back_populates=None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[_ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: bool = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[Type[PropComparator[Any]]] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, **kw: Any, ) -> Relationship[Any]: """Provide a relationship between two mapped classes. @@ -1098,13 +922,6 @@ def relationship( :ref:`error_qzyx` - usage example - :param bake_queries=True: - Legacy parameter, not used. - - .. versionchanged:: 1.4.23 the "lambda caching" system is no longer - used by loader strategies and the ``bake_queries`` parameter - has no effect. - :param cascade: A comma-separated list of cascade rules which determines how Session operations should be "cascaded" from parent to child. @@ -1701,18 +1518,42 @@ def relationship( primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, back_populates=back_populates, + order_by=order_by, + backref=backref, + overlaps=overlaps, + post_update=post_update, + cascade=cascade, + viewonly=viewonly, + lazy=lazy, + passive_deletes=passive_deletes, + passive_updates=passive_updates, + active_history=active_history, + enable_typechecks=enable_typechecks, + foreign_keys=foreign_keys, + remote_side=remote_side, + join_depth=join_depth, + comparator_factory=comparator_factory, + single_parent=single_parent, + innerjoin=innerjoin, + distinct_target_key=distinct_target_key, + load_on_pending=load_on_pending, + query_class=query_class, + info=info, + omit_join=omit_join, + sync_backref=sync_backref, **kw, ) def synonym( - name, - map_column=None, - descriptor=None, - comparator_factory=None, - doc=None, - info=None, -) -> "Synonym[Any]": + name: str, + *, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> Synonym[Any]: """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior of another attribute. @@ -1951,8 +1792,8 @@ def deferred(*columns, **kw): def query_expression( - default_expr: sql.ColumnElement[_T] = sql.null(), -) -> "Mapped[_T]": + default_expr: _ORMColumnExprArgument[_T] = sql.null(), +) -> Mapped[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -2010,33 +1851,33 @@ def clear_mappers(): @overload def aliased( - element: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "AliasedClass[_T]": + element: _EntityType[_O], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedClass[_O]: ... @overload def aliased( - element: "FromClause", - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> "Alias": + element: FromClause, + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> FromClause: ... def aliased( - element: Union[Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, -) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> Union[AliasedClass[_O], FromClause]: """Produce an alias of the given element, usually an :class:`.AliasedClass` instance. @@ -2233,9 +2074,7 @@ def with_polymorphic( ) -def join( - left, right, onclause=None, isouter=False, full=False, join_to_left=None -): +def join(left, right, onclause=None, isouter=False, full=False): r"""Produce an inner join between left and right clauses. :func:`_orm.join` is an extension to the core join interface @@ -2270,16 +2109,11 @@ def join( See :ref:`orm_queryguide_joins` for information on modern usage of ORM level joins. - .. deprecated:: 0.8 - - the ``join_to_left`` parameter is deprecated, and will be removed - in a future release. The parameter has no effect. - """ return _ORMJoin(left, right, onclause, isouter, full) -def outerjoin(left, right, onclause=None, full=False, join_to_left=None): +def outerjoin(left, right, onclause=None, full=False): """Produce a left outer join between left and right clauses. This is the "outer join" version of the :func:`_orm.join` function, diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 4250cdbe1..339844f14 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -2,6 +2,7 @@ from __future__ import annotations import operator from typing import Any +from typing import Callable from typing import Dict from typing import Optional from typing import Tuple @@ -10,7 +11,9 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from sqlalchemy.orm.interfaces import UserDefinedOption +from ..sql import roles +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnElement from ..util.typing import Protocol from ..util.typing import TypeGuard @@ -18,8 +21,12 @@ if TYPE_CHECKING: from .attributes import AttributeImpl from .attributes import CollectionAttributeImpl from .base import PassiveFlag + from .decl_api import registry as _registry_type from .descriptor_props import _CompositeClassProto + from .interfaces import MapperProperty + from .interfaces import UserDefinedOption from .mapper import Mapper + from .relationships import Relationship from .state import InstanceState from .util import AliasedClass from .util import AliasedInsp @@ -27,21 +34,39 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) + +# I would have preferred this were bound=object however it seems +# to not travel in all situations when defined in that way. _O = TypeVar("_O", bound=Any) """The 'ORM mapped object' type. -I would have preferred this were bound=object however it seems -to not travel in all situations when defined in that way. + """ +if TYPE_CHECKING: + _RegistryType = _registry_type + _InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] -_EntityType = Union[_T, "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]"] +_EntityType = Union[ + Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" +] _InstanceDict = Dict[str, Any] _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] +_ORMColumnExprArgument = Union[ + ColumnElement[_T], + _HasClauseElement, + roles.ExpressionElementRole[_T], +] + +# somehow Protocol didn't want to work for this one +_ORMAdapterProto = Callable[ + [_ORMColumnExprArgument[_T], Optional[str]], _ORMColumnExprArgument[_T] +] + class _LoaderCallable(Protocol): def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: @@ -60,10 +85,28 @@ def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: if TYPE_CHECKING: + def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: + ... + + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: + ... + + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: + ... + + def prop_is_relationship( + prop: MapperProperty[Any], + ) -> TypeGuard[Relationship[Any]]: + ... + def is_collection_impl( impl: AttributeImpl, ) -> TypeGuard[CollectionAttributeImpl]: ... else: + insp_is_mapper_property = operator.attrgetter("is_property") + insp_is_mapper = operator.attrgetter("is_mapper") + insp_is_aliased_class = operator.attrgetter("is_aliased_class") is_collection_impl = operator.attrgetter("collection") + prop_is_relationship = operator.attrgetter("_is_relationship") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 33ce96a19..41d944c57 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -44,7 +44,7 @@ from .base import instance_dict as instance_dict from .base import instance_state as instance_state from .base import instance_str from .base import LOAD_AGAINST_COMMITTED -from .base import manager_of_class +from .base import manager_of_class as manager_of_class from .base import Mapped as Mapped # noqa from .base import NEVER_SET # noqa from .base import NO_AUTOFLUSH @@ -52,6 +52,7 @@ from .base import NO_CHANGE # noqa from .base import NO_RAISE from .base import NO_VALUE from .base import NON_PERSISTENT_OK # noqa +from .base import opt_manager_of_class as opt_manager_of_class from .base import PASSIVE_CLASS_MISMATCH # noqa from .base import PASSIVE_NO_FETCH from .base import PASSIVE_NO_FETCH_RELATED # noqa @@ -74,6 +75,7 @@ from ..sql import traversals from ..sql import visitors if TYPE_CHECKING: + from .interfaces import MapperProperty from .state import InstanceState from ..sql.dml import _DMLColumnElement from ..sql.elements import ColumnElement @@ -146,7 +148,7 @@ class QueryableAttribute( self._of_type = of_type self._extra_criteria = extra_criteria - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) # manager is None in the case of AliasedClass if manager: # propagate existing event listeners from @@ -370,7 +372,7 @@ class QueryableAttribute( return "%s.%s" % (self.class_.__name__, self.key) @util.memoized_property - def property(self): + def property(self) -> MapperProperty[_T]: """Return the :class:`.MapperProperty` associated with this :class:`.QueryableAttribute`. diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 3fa855a4b..054d52d83 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -26,24 +26,25 @@ from typing import TypeVar from typing import Union from . import exc +from ._typing import insp_is_mapper from .. import exc as sa_exc from .. import inspection from .. import util from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly -from ..util.typing import Concatenate from ..util.typing import Literal -from ..util.typing import ParamSpec from ..util.typing import Self if typing.TYPE_CHECKING: from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute + from .instrumentation import ClassManager from .mapper import Mapper from .state import InstanceState from ..sql._typing import _InfoType + _T = TypeVar("_T", bound=Any) _O = TypeVar("_O", bound=object) @@ -246,21 +247,15 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_Fn = TypeVar("_Fn", bound=Callable) -_Args = ParamSpec("_Args") +_F = TypeVar("_F", bound=Callable) _Self = TypeVar("_Self") def _assertions( *assertions: Any, -) -> Callable[ - [Callable[Concatenate[_Self, _Fn, _Args], _Self]], - Callable[Concatenate[_Self, _Fn, _Args], _Self], -]: +) -> Callable[[_F], _F]: @util.decorator - def generate( - fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs - ) -> _Self: + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: for assertion in assertions: assertion(self, fn.__name__) fn(self, *args, **kw) @@ -269,13 +264,13 @@ def _assertions( return generate -# these can be replaced by sqlalchemy.ext.instrumentation -# if augmented class instrumentation is enabled. -def manager_of_class(cls): - return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None) +if TYPE_CHECKING: + def manager_of_class(cls: Type[Any]) -> ClassManager: + ... -if TYPE_CHECKING: + def opt_manager_of_class(cls: Type[Any]) -> Optional[ClassManager]: + ... def instance_state(instance: _O) -> InstanceState[_O]: ... @@ -284,6 +279,20 @@ if TYPE_CHECKING: ... else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. + + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) instance_dict = operator.attrgetter("__dict__") @@ -458,11 +467,12 @@ else: _state_mapper = util.dottedgetter("manager.mapper") -@inspection._inspects(type) -def _inspect_mapped_class(class_, configure=False): +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: try: - class_manager = manager_of_class(class_) - if not class_manager.is_mapped: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: return None mapper = class_manager.mapper except exc.NO_STATE: @@ -473,7 +483,28 @@ def _inspect_mapped_class(class_, configure=False): return mapper -def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: +@inspection._inspects(type) +def _inspect_mc(class_: Type[_O]) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: """Given a class, return the primary :class:`_orm.Mapper` associated with the key. @@ -502,8 +533,8 @@ def class_mapper(class_: Type[_T], configure: bool = True) -> Mapper[_T]: class InspectionAttr: - """A base class applied to all ORM objects that can be returned - by the :func:`_sa.inspect` function. + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. The attributes defined here allow the usage of simple boolean checks to test basic facts about the object returned. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 419da65f7..4fee2d383 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -63,11 +63,14 @@ from ..sql.visitors import InternalTraversal if TYPE_CHECKING: from ._typing import _InternalEntityType + from .mapper import Mapper + from .query import Query from ..sql.compiler import _CompilerStackEntry from ..sql.dml import _DMLTableElement from ..sql.elements import ColumnElement from ..sql.selectable import _LabelConventionCallable from ..sql.selectable import SelectBase + from ..sql.type_api import TypeEngine _path_registry = PathRegistry.root @@ -211,6 +214,9 @@ class ORMCompileState(CompileState): _for_refresh_state = False _render_for_subquery = False + attributes: Dict[Any, Any] + global_attributes: Dict[Any, Any] + statement: Union[Select, FromStatement] select_statement: Union[Select, FromStatement] _entities: List[_QueryEntity] @@ -1930,7 +1936,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): assert right_mapper adapter = ORMAdapter( - right, equivalents=right_mapper._equivalent_columns + inspect(right), equivalents=right_mapper._equivalent_columns ) # if an alias() on the right side was generated, @@ -2075,14 +2081,16 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def _column_descriptions( - query_or_select_stmt, compile_state=None, legacy=False + query_or_select_stmt: Union[Query, Select, FromStatement], + compile_state: Optional[ORMSelectCompileState] = None, + legacy: bool = False, ) -> List[ORMColumnDescription]: if compile_state is None: compile_state = ORMSelectCompileState._create_entities_collection( query_or_select_stmt, legacy=legacy ) ctx = compile_state - return [ + d = [ { "name": ent._label_name, "type": ent.type, @@ -2093,17 +2101,10 @@ def _column_descriptions( else None, } for ent, insp_ent in [ - ( - _ent, - ( - inspect(_ent.entity_zero) - if _ent.entity_zero is not None - else None - ), - ) - for _ent in ctx._entities + (_ent, _ent.entity_zero) for _ent in ctx._entities ] ] + return d def _legacy_filter_by_entity_zero(query_or_augmented_select): @@ -2157,6 +2158,11 @@ class _QueryEntity: _null_column_type = False use_id_for_hash = False + _label_name: Optional[str] + type: Union[Type[Any], TypeEngine[Any]] + expr: Union[_InternalEntityType, ColumnElement[Any]] + entity_zero: Optional[_InternalEntityType] + def setup_compile_state(self, compile_state: ORMCompileState) -> None: raise NotImplementedError() @@ -2234,6 +2240,13 @@ class _MapperEntity(_QueryEntity): "_polymorphic_discriminator", ) + expr: _InternalEntityType + mapper: Mapper[Any] + entity_zero: _InternalEntityType + is_aliased_class: bool + path: PathRegistry + _label_name: str + def __init__( self, compile_state, entity, entities_collection, is_current_entities ): @@ -2389,6 +2402,13 @@ class _BundleEntity(_QueryEntity): "supports_single_entity", ) + _entities: List[_QueryEntity] + bundle: Bundle + type: Type[Any] + _label_name: str + supports_single_entity: bool + expr: Bundle + def __init__( self, compile_state, diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 70507015b..0c990f809 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -50,7 +50,7 @@ from ..util import hybridproperty from ..util import typing as compat_typing if typing.TYPE_CHECKING: - from .state import InstanceState # noqa + from .state import InstanceState _T = TypeVar("_T", bound=Any) @@ -280,7 +280,7 @@ class declared_attr(interfaces._MappedAttribute[_T]): # for the span of the declarative scan_attributes() phase. # to achieve this we look at the class manager that's configured. cls = owner - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager is None: if not re.match(r"^__.+__$", self.fget.__name__): # if there is no manager at all, then this class hasn't been @@ -1294,8 +1294,8 @@ def as_declarative(**kw): @inspection._inspects( DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept ) -def _inspect_decl_meta(cls): - mp = _inspect_mapped_class(cls) +def _inspect_decl_meta(cls: Type[Any]) -> Mapper[Any]: + mp: Mapper[Any] = _inspect_mapped_class(cls) if mp is None: if _DeferredMapperConfig.has_cls(cls): _DeferredMapperConfig.raise_unmapped_for_cls(cls) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 804d05ce1..9c79a4172 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -12,6 +12,8 @@ import collections from typing import Any from typing import Dict from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING import weakref from . import attributes @@ -42,6 +44,10 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + def _declared_mapping_info(cls): # deferred mapping @@ -121,7 +127,7 @@ def _dive_for_cls_manager(cls): return None for base in cls.__mro__: - manager = attributes.manager_of_class(base) + manager = attributes.opt_manager_of_class(base) if manager: return manager return None @@ -171,7 +177,7 @@ class _MapperConfig: @classmethod def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): - manager = attributes.manager_of_class(cls) + manager = attributes.opt_manager_of_class(cls) if manager and manager.class_ is cls_: raise exc.InvalidRequestError( "Class %r already has been " "instrumented declaratively" % cls @@ -191,7 +197,12 @@ class _MapperConfig: return cfg_cls(registry, cls_, dict_, table, mapper_kw) - def __init__(self, registry, cls_, mapper_kw): + def __init__( + self, + registry: _RegistryType, + cls_: Type[Any], + mapper_kw: Dict[str, Any], + ): self.cls = util.assert_arg_type(cls_, type, "cls_") self.classname = cls_.__name__ self.properties = util.OrderedDict() @@ -206,7 +217,7 @@ class _MapperConfig: init_method=registry.constructor, ) else: - manager = attributes.manager_of_class(self.cls) + manager = attributes.opt_manager_of_class(self.cls) if not manager or not manager.is_mapped: raise exc.InvalidRequestError( "Class %s has no primary mapper configured. Configure " diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8beac472e..4738d8c2c 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -122,7 +122,11 @@ class DescriptorProperty(MapperProperty[_T]): _CompositeAttrType = Union[ - str, "Column[Any]", "MappedColumn[Any]", "InstrumentedAttribute[Any]" + str, + "Column[_T]", + "MappedColumn[_T]", + "InstrumentedAttribute[_T]", + "Mapped[_T]", ] diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index c531e7cf1..331c224ee 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -11,6 +11,9 @@ from __future__ import annotations from typing import Any +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING import weakref from . import instrumentation @@ -27,6 +30,10 @@ from .. import exc from .. import util from ..util.compat import inspect_getfullargspec +if TYPE_CHECKING: + from ._typing import _O + from .instrumentation import ClassManager + class InstrumentationEvents(event.Events): """Events related to class instrumentation events. @@ -214,7 +221,7 @@ class InstanceEvents(event.Events): if issubclass(target, mapperlib.Mapper): return instrumentation.ClassManager else: - manager = instrumentation.manager_of_class(target) + manager = instrumentation.opt_manager_of_class(target) if manager: return manager else: @@ -613,8 +620,8 @@ class _EventsHold(event.RefCollection): class _InstanceEventsHold(_EventsHold): all_holds = weakref.WeakKeyDictionary() - def resolve(self, class_): - return instrumentation.manager_of_class(class_) + def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: + return instrumentation.opt_manager_of_class(class_) class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents): pass diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py index 00829ecbb..529a7cd01 100644 --- a/lib/sqlalchemy/orm/exc.py +++ b/lib/sqlalchemy/orm/exc.py @@ -203,7 +203,10 @@ def _default_unmapped(cls) -> Optional[str]: try: mappers = base.manager_of_class(cls).mappers - except (TypeError,) + NO_STATE: + except ( + UnmappedClassError, + TypeError, + ) + NO_STATE: mappers = {} name = _safe_cls_name(cls) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 0d4b630da..88ceacd07 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -33,10 +33,13 @@ alternate instrumentation forms. from __future__ import annotations from typing import Any +from typing import Callable from typing import Dict from typing import Generic from typing import Optional from typing import Set +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar import weakref @@ -53,7 +56,9 @@ from ..util import HasMemoized from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _RegistryType from .attributes import InstrumentedAttribute + from .decl_base import _MapperConfig from .mapper import Mapper from .state import InstanceState from ..event import dispatcher @@ -72,6 +77,11 @@ class _ExpiredAttributeLoaderProto(Protocol): ... +class _ManagerFactory(Protocol): + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: + ... + + class ClassManager( HasMemoized, Dict[str, "InstrumentedAttribute[Any]"], @@ -90,12 +100,12 @@ class ClassManager( expired_attribute_loader: _ExpiredAttributeLoaderProto "previously known as deferred_scalar_loader" - init_method = None + init_method: Optional[Callable[..., None]] - factory = None + factory: Optional[_ManagerFactory] - declarative_scan = None - registry = None + declarative_scan: Optional[weakref.ref[_MapperConfig]] = None + registry: Optional[_RegistryType] = None @property @util.deprecated( @@ -122,11 +132,13 @@ class ClassManager( self.local_attrs = {} self.originals = {} self._finalized = False + self.factory = None + self.init_method = None self._bases = [ mgr for mgr in [ - manager_of_class(base) + opt_manager_of_class(base) for base in self.class_.__bases__ if isinstance(base, type) ] @@ -139,7 +151,7 @@ class ClassManager( self.dispatch._events._new_classmanager_instance(class_, self) for basecls in class_.__mro__: - mgr = manager_of_class(basecls) + mgr = opt_manager_of_class(basecls) if mgr is not None: self.dispatch._update(mgr.dispatch) @@ -155,16 +167,18 @@ class ClassManager( def _update_state( self, - finalize=False, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, - ): + finalize: bool = False, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[ + _ExpiredAttributeLoaderProto + ] = None, + init_method: Optional[Callable[..., None]] = None, + ) -> None: if mapper: - self.mapper = mapper + self.mapper = mapper # type: ignore[assignment] if registry: registry._add_manager(self) if declarative_scan: @@ -350,7 +364,7 @@ class ClassManager( def subclass_managers(self, recursive): for cls in self.class_.__subclasses__(): - mgr = manager_of_class(cls) + mgr = opt_manager_of_class(cls) if mgr is not None and mgr is not self: yield mgr if recursive: @@ -374,7 +388,7 @@ class ClassManager( self._reset_memoizations() del self[key] for cls in self.class_.__subclasses__(): - manager = manager_of_class(cls) + manager = opt_manager_of_class(cls) if manager: manager.uninstrument_attribute(key, True) @@ -523,7 +537,7 @@ class _SerializeManager: manager.dispatch.pickle(state, d) def __call__(self, state, inst, state_dict): - state.manager = manager = manager_of_class(self.class_) + state.manager = manager = opt_manager_of_class(self.class_) if manager is None: raise exc.UnmappedInstanceError( inst, @@ -546,9 +560,9 @@ class _SerializeManager: class InstrumentationFactory: """Factory for new ClassManager instances.""" - def create_manager_for_cls(self, class_): + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: assert class_ is not None - assert manager_of_class(class_) is None + assert opt_manager_of_class(class_) is None # give a more complicated subclass # a chance to do what it wants here @@ -557,6 +571,8 @@ class InstrumentationFactory: if factory is None: factory = ClassManager manager = factory(class_) + else: + assert manager is not None self._check_conflicts(class_, factory) @@ -564,11 +580,15 @@ class InstrumentationFactory: return manager - def _locate_extended_factory(self, class_): + def _locate_extended_factory( + self, class_: Type[_O] + ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]: """Overridden by a subclass to do an extended lookup.""" return None, None - def _check_conflicts(self, class_, factory): + def _check_conflicts( + self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] + ): """Overridden by a subclass to test for conflicting factories.""" return @@ -590,24 +610,25 @@ instance_state = _default_state_getter = base.instance_state instance_dict = _default_dict_getter = base.instance_dict manager_of_class = _default_manager_getter = base.manager_of_class +opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class def register_class( - class_, - finalize=True, - mapper=None, - registry=None, - declarative_scan=None, - expired_attribute_loader=None, - init_method=None, -): + class_: Type[_O], + finalize: bool = True, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None, + init_method: Optional[Callable[..., None]] = None, +) -> ClassManager[_O]: """Register class instrumentation. Returns the existing or newly created class manager. """ - manager = manager_of_class(class_) + manager = opt_manager_of_class(class_) if manager is None: manager = _instrumentation_factory.create_manager_for_cls(class_) manager._update_state( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index abc1300d8..0ca62b7e3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -21,10 +21,15 @@ from __future__ import annotations import collections import typing from typing import Any +from typing import Callable from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator from typing import List from typing import Optional from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TypeVar @@ -45,7 +50,6 @@ from .base import NotExtension as NotExtension from .base import ONETOMANY as ONETOMANY from .base import SQLORMOperations from .. import ColumnElement -from .. import inspect from .. import inspection from .. import util from ..sql import operators @@ -53,19 +57,47 @@ from ..sql import roles from ..sql import visitors from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey -from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.type_api import TypeEngine from ..util.typing import TypedDict + if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _ORMAdapterProto + from ._typing import _ORMColumnExprArgument + from .attributes import InstrumentedAttribute + from .context import _MapperEntity + from .context import ORMCompileState from .decl_api import RegistryType + from .loading import _PopulatorDict + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .path_registry import PathRegistry + from .query import Query + from .session import Session + from .state import InstanceState + from .strategy_options import _LoadElement + from .util import AliasedInsp + from .util import CascadeOptions + from .util import ORMAdapter + from ..engine.result import Result + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.operators import OperatorType + from ..sql.util import ColumnAdapter + from ..sql.visitors import _TraverseInternalsType _T = TypeVar("_T", bound=Any) +_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") + class ORMStatementRole(roles.StatementRole): __slots__ = () @@ -91,7 +123,9 @@ class ORMFromClauseRole(roles.StrictFromClauseRole): class ORMColumnDescription(TypedDict): name: str - type: Union[Type, TypeEngine] + # TODO: add python_type and sql_type here; combining them + # into "type" is a bad idea + type: Union[Type[Any], TypeEngine[Any]] aliased: bool expr: _ColumnsClauseArgument entity: Optional[_ColumnsClauseArgument] @@ -102,10 +136,10 @@ class _IntrospectsAnnotations: def declarative_scan( self, - registry: "RegistryType", - cls: type, + registry: RegistryType, + cls: Type[Any], key: str, - annotation: Optional[type], + annotation: Optional[Type[Any]], is_dataclass_field: Optional[bool], ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -124,12 +158,12 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" raise NotImplementedError() @property - def columns_to_assign(self) -> List[Column]: + def columns_to_assign(self) -> List[Column[_T]]: """A list of Column objects that should be declaratively added to the new Table object. @@ -139,7 +173,10 @@ class _MapsColumns(_MappedAttribute[_T]): @inspection._self_inspects class MapperProperty( - HasCacheKey, _MappedAttribute[_T], InspectionAttr, util.MemoizedSlots + HasCacheKey, + _MappedAttribute[_T], + InspectionAttrInfo, + util.MemoizedSlots, ): """Represent a particular class attribute mapped by :class:`_orm.Mapper`. @@ -160,12 +197,12 @@ class MapperProperty( "info", ) - _cache_key_traversal = [ + _cache_key_traversal: _TraverseInternalsType = [ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("key", visitors.ExtendedInternalTraversal.dp_string), ] - cascade = frozenset() + cascade: Optional[CascadeOptions] = None """The set of 'cascade' attribute names. This collection is checked before the 'cascade_iterator' method is called. @@ -184,14 +221,20 @@ class MapperProperty( """The :class:`_orm.PropComparator` instance that implements SQL expression construction on behalf of this mapped attribute.""" - @property - def _links_to_entity(self): - """True if this MapperProperty refers to a mapped entity. + key: str + """name of class attribute""" - Should only be True for Relationship, False for all others. + parent: Mapper[Any] + """the :class:`.Mapper` managing this property.""" - """ - raise NotImplementedError() + _is_relationship = False + + _links_to_entity: bool + """True if this MapperProperty refers to a mapped entity. + + Should only be True for Relationship, False for all others. + + """ def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined @@ -217,7 +260,14 @@ class MapperProperty( """ return {} - def setup(self, context, query_entity, path, adapter, **kwargs): + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + adapter: Optional[ColumnAdapter], + **kwargs: Any, + ) -> None: """Called by Query for the purposes of constructing a SQL statement. Each MapperProperty associated with the target mapper processes the @@ -227,16 +277,30 @@ class MapperProperty( """ def create_row_processor( - self, context, query_entity, path, mapper, result, adapter, populators - ): + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: PathRegistry, + mapper: Mapper[Any], + result: Result, + adapter: Optional[ColumnAdapter], + populators: _PopulatorDict, + ) -> None: """Produce row processing functions and append to the given set of populators lists. """ def cascade_iterator( - self, type_, state, dict_, visited_states, halt_on=None - ): + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: """Iterate through instances related to the given instance for a particular 'cascade', starting with this MapperProperty. @@ -251,7 +315,7 @@ class MapperProperty( return iter(()) - def set_parent(self, parent, init): + def set_parent(self, parent: Mapper[Any], init: bool) -> None: """Set the parent mapper that references this MapperProperty. This method is overridden by some subclasses to perform extra @@ -260,7 +324,7 @@ class MapperProperty( """ self.parent = parent - def instrument_class(self, mapper): + def instrument_class(self, mapper: Mapper[Any]) -> None: """Hook called by the Mapper to the property to initiate instrumentation of the class attribute managed by this MapperProperty. @@ -280,11 +344,11 @@ class MapperProperty( """ - def __init__(self): + def __init__(self) -> None: self._configure_started = False self._configure_finished = False - def init(self): + def init(self) -> None: """Called after all mappers are created to assemble relationships between mappers and perform other post-mapper-creation initialization steps. @@ -296,7 +360,7 @@ class MapperProperty( self._configure_finished = True @property - def class_attribute(self): + def class_attribute(self) -> InstrumentedAttribute[_T]: """Return the class-bound descriptor corresponding to this :class:`.MapperProperty`. @@ -319,9 +383,9 @@ class MapperProperty( """ - return getattr(self.parent.class_, self.key) + return getattr(self.parent.class_, self.key) # type: ignore - def do_init(self): + def do_init(self) -> None: """Perform subclass-specific initialization post-mapper-creation steps. @@ -330,7 +394,7 @@ class MapperProperty( """ - def post_instrument_class(self, mapper): + def post_instrument_class(self, mapper: Mapper[Any]) -> None: """Perform instrumentation adjustments that need to occur after init() has completed. @@ -347,21 +411,21 @@ class MapperProperty( def merge( self, - session, - source_state, - source_dict, - dest_state, - dest_dict, - load, - _recursive, - _resolve_conflict_map, - ): + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Set[InstanceState[Any]], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: """Merge the attribute represented by this ``MapperProperty`` from source to destination object. """ - def __repr__(self): + def __repr__(self) -> str: return "<%s at 0x%x; %s>" % ( self.__class__.__name__, id(self), @@ -452,21 +516,28 @@ class PropComparator(SQLORMOperations[_T]): """ - __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" + __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" + _parententity: _InternalEntityType[Any] + _adapt_to_entity: Optional[AliasedInsp[Any]] + def __init__( self, - prop, - parentmapper, - adapt_to_entity=None, + prop: MapperProperty[_T], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, ): - self.prop = self.property = prop + self.prop = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity - def __clause_element__(self): + @util.ro_non_memoized_property + def property(self) -> Optional[MapperProperty[_T]]: + return self.prop + + def __clause_element__(self) -> _ORMColumnExprArgument[_T]: raise NotImplementedError("%r" % self) def _bulk_update_tuples( @@ -480,22 +551,24 @@ class PropComparator(SQLORMOperations[_T]): """ - return [(self.__clause_element__(), value)] + return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] - def adapt_to_entity(self, adapt_to_entity): + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> PropComparator[_T]: """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ return self.__class__(self.prop, self._parententity, adapt_to_entity) - @property - def _parentmapper(self): + @util.ro_non_memoized_property + def _parentmapper(self) -> Mapper[Any]: """legacy; this is renamed to _parententity to be compatible with QueryableAttribute.""" - return inspect(self._parententity).mapper + return self._parententity.mapper - @property - def _propagate_attrs(self): + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: # this suits the case in coercions where we don't actually # call ``__clause_element__()`` but still need to get # resolved._propagate_attrs. See #6558. @@ -507,12 +580,14 @@ class PropComparator(SQLORMOperations[_T]): ) def _criterion_exists( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs: Any + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[Any]: return self.prop.comparator._criterion_exists(criterion, **kwargs) - @property - def adapter(self): + @util.ro_non_memoized_property + def adapter(self) -> Optional[_ORMAdapterProto[_T]]: """Produce a callable that adapts column expressions to suit an aliased version of this comparator. @@ -522,20 +597,20 @@ class PropComparator(SQLORMOperations[_T]): else: return self._adapt_to_entity._adapt_element - @util.non_memoized_property + @util.ro_non_memoized_property def info(self) -> _InfoType: - return self.property.info + return self.prop.info @staticmethod - def _any_op(a, b, **kwargs): + def _any_op(a: Any, b: Any, **kwargs: Any) -> Any: return a.any(b, **kwargs) @staticmethod - def _has_op(left, other, **kwargs): + def _has_op(left: Any, other: Any, **kwargs: Any) -> Any: return left.has(other, **kwargs) @staticmethod - def _of_type_op(a, class_): + def _of_type_op(a: Any, class_: Any) -> Any: return a.of_type(class_) any_op = cast(operators.OperatorType, _any_op) @@ -545,16 +620,16 @@ class PropComparator(SQLORMOperations[_T]): if typing.TYPE_CHECKING: def operate( - self, op: operators.OperatorType, *other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... def reverse_operate( - self, op: operators.OperatorType, other: Any, **kwargs: Any - ) -> "SQLCoreOperations[Any]": + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... - def of_type(self, class_) -> "SQLORMOperations[_T]": + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -578,9 +653,11 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.of_type_op, class_) + return self.operate(PropComparator.of_type_op, class_) # type: ignore - def and_(self, *criteria) -> "SQLORMOperations[_T]": + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: """Add additional criteria to the ON clause that's represented by this relationship attribute. @@ -606,10 +683,12 @@ class PropComparator(SQLORMOperations[_T]): :func:`.with_loader_criteria` """ - return self.operate(operators.and_, *criteria) + return self.operate(operators.and_, *criteria) # type: ignore def any( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. @@ -626,10 +705,14 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.any_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.any_op, criterion, **kwargs + ) def has( - self, criterion: Optional[SQLCoreOperations[Any]] = None, **kwargs + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, ) -> ColumnElement[bool]: r"""Return a SQL expression representing true if this element references a member which meets the given criterion. @@ -646,7 +729,9 @@ class PropComparator(SQLORMOperations[_T]): """ - return self.operate(PropComparator.has_op, criterion, **kwargs) + return self.operate( # type: ignore + PropComparator.has_op, criterion, **kwargs + ) class StrategizedProperty(MapperProperty[_T]): @@ -674,23 +759,30 @@ class StrategizedProperty(MapperProperty[_T]): "strategy_key", ) inherit_cache = True - strategy_wildcard_key = None + strategy_wildcard_key: ClassVar[str] strategy_key: Tuple[Any, ...] - def _memoized_attr__wildcard_token(self): + _strategies: Dict[Tuple[Any, ...], LoaderStrategy] + + def _memoized_attr__wildcard_token(self) -> Tuple[str]: return ( f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", ) - def _memoized_attr__default_path_loader_key(self): + def _memoized_attr__default_path_loader_key( + self, + ) -> Tuple[str, Tuple[str]]: return ( "loader", (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",), ) - def _get_context_loader(self, context, path): - load = None + def _get_context_loader( + self, context: ORMCompileState, path: AbstractEntityRegistry + ) -> Optional[_LoadElement]: + + load: Optional[_LoadElement] = None search_path = path[self] @@ -714,7 +806,7 @@ class StrategizedProperty(MapperProperty[_T]): return load - def _get_strategy(self, key): + def _get_strategy(self, key: Tuple[Any, ...]) -> LoaderStrategy: try: return self._strategies[key] except KeyError: @@ -768,11 +860,13 @@ class StrategizedProperty(MapperProperty[_T]): ): self.strategy.init_class_attribute(mapper) - _all_strategies = collections.defaultdict(dict) + _all_strategies: collections.defaultdict[ + Type[Any], Dict[Tuple[Any, ...], Type[LoaderStrategy]] + ] = collections.defaultdict(dict) @classmethod - def strategy_for(cls, **kw): - def decorate(dec_cls): + def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]: + def decorate(dec_cls: _TLS) -> _TLS: # ensure each subclass of the strategy has its # own _strategy_keys collection if "_strategy_keys" not in dec_cls.__dict__: @@ -785,7 +879,9 @@ class StrategizedProperty(MapperProperty[_T]): return decorate @classmethod - def _strategy_lookup(cls, requesting_property, *key): + def _strategy_lookup( + cls, requesting_property: MapperProperty[Any], *key: Any + ) -> Type[LoaderStrategy]: requesting_property.parent._with_polymorphic_mappers for prop_cls in cls.__mro__: @@ -984,10 +1080,10 @@ class MapperOption(ORMOption): """ - def process_query(self, query): + def process_query(self, query: Query[Any]) -> None: """Apply a modification to the given :class:`_query.Query`.""" - def process_query_conditionally(self, query): + def process_query_conditionally(self, query: Query[Any]) -> None: """same as process_query(), except that this option may not apply to the given query. @@ -1034,7 +1130,11 @@ class LoaderStrategy: "strategy_opts", ) - def __init__(self, parent, strategy_key): + _strategy_keys: ClassVar[List[Tuple[Any, ...]]] + + def __init__( + self, parent: MapperProperty[Any], strategy_key: Tuple[Any, ...] + ): self.parent_property = parent self.is_class_level = False self.parent = self.parent_property.parent @@ -1042,12 +1142,18 @@ class LoaderStrategy: self.strategy_key = strategy_key self.strategy_opts = dict(strategy_key) - def init_class_attribute(self, mapper): + def init_class_attribute(self, mapper: Mapper[Any]) -> None: pass def setup_query( - self, compile_state, query_entity, path, loadopt, adapter, **kwargs - ): + self, + compile_state: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: """Establish column and other state for a given QueryContext. This method fulfills the contract specified by MapperProperty.setup(). @@ -1059,15 +1165,15 @@ class LoaderStrategy: def create_row_processor( self, - context, - query_entity, - path, - loadopt, - mapper, - result, - adapter, - populators, - ): + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + mapper: Mapper[Any], + result: Result, + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: """Establish row processing functions for a given QueryContext. This method fulfills the contract specified by diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index ae083054c..d9949eb7a 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -16,7 +16,9 @@ as well as some of the attribute loading strategies. from __future__ import annotations from typing import Any +from typing import Dict from typing import Iterable +from typing import List from typing import Mapping from typing import Optional from typing import Sequence @@ -65,6 +67,9 @@ _O = TypeVar("_O", bound=object) _new_runid = util.counter() +_PopulatorDict = Dict[str, List[Tuple[str, Any]]] + + def instances(cursor, context): """Return a :class:`.Result` given an ORM query context. @@ -383,7 +388,7 @@ def get_from_identity( mapper: Mapper[_O], key: _IdentityKeyType[_O], passive: PassiveFlag, -) -> Union[Optional[_O], LoaderCallableStatus]: +) -> Union[LoaderCallableStatus, Optional[_O]]: """Look up the given key in the given session's identity map, check the object for expired state if found. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index abe11cc68..b37c080ea 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,12 +23,23 @@ import sys import threading from typing import Any from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict from typing import Generic +from typing import Iterable from typing import Iterator +from typing import List +from typing import Mapping from typing import Optional +from typing import Sequence +from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union import weakref from . import attributes @@ -39,8 +50,8 @@ from . import properties from . import util as orm_util from ._typing import _O from .base import _class_to_mapper +from .base import _parse_mapper_argument from .base import _state_mapper -from .base import class_mapper from .base import PassiveFlag from .base import state_str from .interfaces import _MappedAttribute @@ -58,6 +69,8 @@ from .. import log from .. import schema from .. import sql from .. import util +from ..event import dispatcher +from ..event import EventTarget from ..sql import base as sql_base from ..sql import coercions from ..sql import expression @@ -65,26 +78,68 @@ from ..sql import operators from ..sql import roles from ..sql import util as sql_util from ..sql import visitors +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .decl_api import registry + from .dependency import DependencyProcessor + from .descriptor_props import Composite + from .descriptor_props import Synonym + from .events import MapperEvents from .instrumentation import ClassManager + from .path_registry import AbstractEntityRegistry + from .path_registry import CachingEntityRegistry + from .properties import ColumnProperty + from .relationships import Relationship from .state import InstanceState + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement from ..sql.schema import Column + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..sql.selectable import TableClause + from ..sql.util import ColumnAdapter + from ..util import OrderedSet -_mapper_registries = weakref.WeakKeyDictionary() +_T = TypeVar("_T", bound=Any) +_MP = TypeVar("_MP", bound="MapperProperty[Any]") -def _all_registries(): +_WithPolymorphicArg = Union[ + Literal["*"], + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ], + Sequence[Union["Mapper[Any]", Type[Any]]], +] + + +_mapper_registries: weakref.WeakKeyDictionary[ + _RegistryType, bool +] = weakref.WeakKeyDictionary() + + +def _all_registries() -> Set[registry]: with _CONFIGURE_MUTEX: return set(_mapper_registries) -def _unconfigured_mappers(): +def _unconfigured_mappers() -> Iterator[Mapper[Any]]: for reg in _all_registries(): for mapper in reg._mappers_to_configure(): yield mapper @@ -107,9 +162,11 @@ _CONFIGURE_MUTEX = threading.RLock() class Mapper( ORMFromClauseRole, ORMEntityColumnsClauseRole, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, InspectionAttr, log.Identified, + inspection.Inspectable["Mapper[_O]"], + EventTarget, Generic[_O], ): """Defines an association between a Python class and a database table or @@ -123,18 +180,11 @@ class Mapper( """ + dispatch: dispatcher[Mapper[_O]] + _dispose_called = False _ready_for_configure = False - class_: Type[_O] - """The class to which this :class:`_orm.Mapper` is mapped.""" - - _identity_class: Type[_O] - - always_refresh: bool - allow_partial_pks: bool - version_id_col: Optional[ColumnElement[Any]] - @util.deprecated_params( non_primary=( "1.3", @@ -148,33 +198,39 @@ class Mapper( def __init__( self, class_: Type[_O], - local_table=None, - properties=None, - primary_key=None, - non_primary=False, - inherits=None, - inherit_condition=None, - inherit_foreign_keys=None, - always_refresh=False, - version_id_col=None, - version_id_generator=None, - polymorphic_on=None, - _polymorphic_map=None, - polymorphic_identity=None, - concrete=False, - with_polymorphic=None, - polymorphic_load=None, - allow_partial_pks=True, - batch=True, - column_prefix=None, - include_properties=None, - exclude_properties=None, - passive_updates=True, - passive_deletes=False, - confirm_deleted_rows=True, - eager_defaults=False, - legacy_is_orphan=False, - _compiled_cache_size=100, + local_table: Optional[FromClause] = None, + properties: Optional[Mapping[str, MapperProperty[Any]]] = None, + primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None, + non_primary: bool = False, + inherits: Optional[Union[Mapper[Any], Type[Any]]] = None, + inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None, + inherit_foreign_keys: Optional[ + Sequence[_ORMColumnExprArgument[Any]] + ] = None, + always_refresh: bool = False, + version_id_col: Optional[_ORMColumnExprArgument[Any]] = None, + version_id_generator: Optional[ + Union[Literal[False], Callable[[Any], Any]] + ] = None, + polymorphic_on: Optional[ + Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]] + ] = None, + _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None, + polymorphic_identity: Optional[Any] = None, + concrete: bool = False, + with_polymorphic: Optional[_WithPolymorphicArg] = None, + polymorphic_load: Optional[Literal["selectin", "inline"]] = None, + allow_partial_pks: bool = True, + batch: bool = True, + column_prefix: Optional[str] = None, + include_properties: Optional[Sequence[str]] = None, + exclude_properties: Optional[Sequence[str]] = None, + passive_updates: bool = True, + passive_deletes: bool = False, + confirm_deleted_rows: bool = True, + eager_defaults: bool = False, + legacy_is_orphan: bool = False, + _compiled_cache_size: int = 100, ): r"""Direct constructor for a new :class:`_orm.Mapper` object. @@ -593,8 +649,6 @@ class Mapper( self.class_.__name__, ) - self.class_manager = None - self._primary_key_argument = util.to_list(primary_key) self.non_primary = non_primary @@ -623,17 +677,36 @@ class Mapper( self.concrete = concrete self.single = False - self.inherits = inherits + + if inherits is not None: + self.inherits = _parse_mapper_argument(inherits) + else: + self.inherits = None + if local_table is not None: self.local_table = coercions.expect( roles.StrictFromClauseRole, local_table ) + elif self.inherits: + # note this is a new flow as of 2.0 so that + # .local_table need not be Optional + self.local_table = self.inherits.local_table + self.single = True else: - self.local_table = None + raise sa_exc.ArgumentError( + f"Mapper[{self.class_.__name__}(None)] has None for a " + "primary table argument and does not specify 'inherits'" + ) + + if inherit_condition is not None: + self.inherit_condition = coercions.expect( + roles.OnClauseRole, inherit_condition + ) + else: + self.inherit_condition = None - self.inherit_condition = inherit_condition self.inherit_foreign_keys = inherit_foreign_keys - self._init_properties = properties or {} + self._init_properties = dict(properties) if properties else {} self._delete_orphans = [] self.batch = batch self.eager_defaults = eager_defaults @@ -694,7 +767,10 @@ class Mapper( # while a configure_mappers() is occurring (and defer a # configure_mappers() until construction succeeds) with _CONFIGURE_MUTEX: - self.dispatch._events._new_mapper_instance(class_, self) + + cast("MapperEvents", self.dispatch._events)._new_mapper_instance( + class_, self + ) self._configure_inheritance() self._configure_class_instrumentation() self._configure_properties() @@ -704,16 +780,21 @@ class Mapper( self._log("constructed") self._expire_memoizations() - # major attributes initialized at the classlevel so that - # they can be Sphinx-documented. + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + # ### BEGIN + # ATTRIBUTE DECLARATIONS START HERE is_mapper = True """Part of the inspection API.""" represents_outer_join = False + registry: _RegistryType + @property - def mapper(self): + def mapper(self) -> Mapper[_O]: """Part of the inspection API. Returns self. @@ -721,9 +802,6 @@ class Mapper( """ return self - def _gen_cache_key(self, anon_map, bindparams): - return (self,) - @property def entity(self): r"""Part of the inspection API. @@ -733,49 +811,109 @@ class Mapper( """ return self.class_ - local_table = None - """The :class:`_expression.Selectable` which this :class:`_orm.Mapper` - manages. + class_: Type[_O] + """The class to which this :class:`_orm.Mapper` is mapped.""" + + _identity_class: Type[_O] + + _delete_orphans: List[Tuple[str, Type[Any]]] + _dependency_processors: List[DependencyProcessor] + _memoized_values: Dict[Any, Callable[[], Any]] + _inheriting_mappers: util.WeakSequence[Mapper[Any]] + _all_tables: Set[Table] + + _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] + _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] + + _props: util.OrderedDict[str, MapperProperty[Any]] + _init_properties: Dict[str, MapperProperty[Any]] + + _columntoproperty: _ColumnMapping + + _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]] + _validate_polymorphic_identity: Optional[ + Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] + ] + + tables: Sequence[Table] + """A sequence containing the collection of :class:`_schema.Table` objects + which this :class:`_orm.Mapper` is aware of. + + If the mapper is mapped to a :class:`_expression.Join`, or an + :class:`_expression.Alias` + representing a :class:`_expression.Select`, the individual + :class:`_schema.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]] + """An immutable dictionary of attributes which have been decorated + using the :func:`_orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] + + with_polymorphic: Optional[ + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ] + ] + + version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]] + + local_table: FromClause + """The immediate :class:`_expression.FromClause` which this + :class:`_orm.Mapper` refers towards. - Typically is an instance of :class:`_schema.Table` or - :class:`_expression.Alias`. - May also be ``None``. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. The "local" table is the selectable that the :class:`_orm.Mapper` is directly responsible for managing from an attribute access and flush perspective. For - non-inheriting mappers, the local table is the same as the - "mapped" table. For joined-table inheritance mappers, local_table - will be the particular sub-table of the overall "join" which - this :class:`_orm.Mapper` represents. If this mapper is a - single-table inheriting mapper, local_table will be ``None``. + non-inheriting mappers, :attr:`.Mapper.local_table` will be the same + as :attr:`.Mapper.persist_selectable`. For inheriting mappers, + :attr:`.Mapper.local_table` refers to the specific portion of + :attr:`.Mapper.persist_selectable` that includes the columns to which + this :class:`.Mapper` is loading/persisting, such as a particular + :class:`.Table` within a join. .. seealso:: :attr:`_orm.Mapper.persist_selectable`. + :attr:`_orm.Mapper.selectable`. + """ - persist_selectable = None - """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper` + persist_selectable: FromClause + """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper` is mapped. - Typically an instance of :class:`_schema.Table`, - :class:`_expression.Join`, or :class:`_expression.Alias`. - - The :attr:`_orm.Mapper.persist_selectable` is separate from - :attr:`_orm.Mapper.selectable` in that the former represents columns - that are mapped on this class or its superclasses, whereas the - latter may be a "polymorphic" selectable that contains additional columns - which are in fact mapped on subclasses only. + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. - "persist selectable" is the "thing the mapper writes to" and - "selectable" is the "thing the mapper selects from". - - :attr:`_orm.Mapper.persist_selectable` is also separate from - :attr:`_orm.Mapper.local_table`, which represents the set of columns that - are locally mapped on this class directly. + The :attr:`_orm.Mapper.persist_selectable` is similar to + :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that + represents the inheriting class hierarchy overall in an inheritance + scenario. + :attr.`.Mapper.persist_selectable` is also separate from the + :attr:`.Mapper.selectable` attribute, the latter of which may be an + alternate subquery used for selecting columns. + :attr.`.Mapper.persist_selectable` is oriented towards columns that + will be written on a persist operation. .. seealso:: @@ -785,16 +923,15 @@ class Mapper( """ - inherits = None + inherits: Optional[Mapper[Any]] """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper` inherits from, if any. - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - """ - configured = False + inherit_condition: Optional[ColumnElement[bool]] + + configured: bool = False """Represent ``True`` if this :class:`_orm.Mapper` has been configured. This is a *read only* attribute determined during mapper construction. @@ -806,7 +943,7 @@ class Mapper( """ - concrete = None + concrete: bool """Represent ``True`` if this :class:`_orm.Mapper` is a concrete inheritance mapper. @@ -815,21 +952,6 @@ class Mapper( """ - tables = None - """An iterable containing the collection of :class:`_schema.Table` objects - which this :class:`_orm.Mapper` is aware of. - - If the mapper is mapped to a :class:`_expression.Join`, or an - :class:`_expression.Alias` - representing a :class:`_expression.Select`, the individual - :class:`_schema.Table` - objects that comprise the full construct will be represented here. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - primary_key: Tuple[Column[Any], ...] """An iterable containing the collection of :class:`_schema.Column` objects @@ -854,14 +976,6 @@ class Mapper( """ - class_: Type[_O] - """The Python class which this :class:`_orm.Mapper` maps. - - This is a *read only* attribute determined during mapper construction. - Behavior is undefined if directly modified. - - """ - class_manager: ClassManager[_O] """The :class:`.ClassManager` which maintains event listeners and class-bound descriptors for this :class:`_orm.Mapper`. @@ -871,7 +985,7 @@ class Mapper( """ - single = None + single: bool """Represent ``True`` if this :class:`_orm.Mapper` is a single table inheritance mapper. @@ -882,7 +996,7 @@ class Mapper( """ - non_primary = None + non_primary: bool """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary" mapper, e.g. a mapper that is used only to select rows but not for persistence management. @@ -892,7 +1006,7 @@ class Mapper( """ - polymorphic_on = None + polymorphic_on: Optional[ColumnElement[Any]] """The :class:`_schema.Column` or SQL expression specified as the ``polymorphic_on`` argument for this :class:`_orm.Mapper`, within an inheritance scenario. @@ -906,7 +1020,7 @@ class Mapper( """ - polymorphic_map = None + polymorphic_map: Dict[Any, Mapper[Any]] """A mapping of "polymorphic identity" identifiers mapped to :class:`_orm.Mapper` instances, within an inheritance scenario. @@ -922,7 +1036,7 @@ class Mapper( """ - polymorphic_identity = None + polymorphic_identity: Optional[Any] """Represent an identifier which is matched against the :attr:`_orm.Mapper.polymorphic_on` column during result row loading. @@ -935,7 +1049,7 @@ class Mapper( """ - base_mapper = None + base_mapper: Mapper[Any] """The base-most :class:`_orm.Mapper` in an inheritance chain. In a non-inheriting scenario, this attribute will always be this @@ -948,7 +1062,7 @@ class Mapper( """ - columns = None + columns: ReadOnlyColumnCollection[str, Column[Any]] """A collection of :class:`_schema.Column` or other scalar expression objects maintained by this :class:`_orm.Mapper`. @@ -965,25 +1079,16 @@ class Mapper( """ - validators = None - """An immutable dictionary of attributes which have been decorated - using the :func:`_orm.validates` decorator. - - The dictionary contains string attribute names as keys - mapped to the actual validation method. - - """ - - c = None + c: ReadOnlyColumnCollection[str, Column[Any]] """A synonym for :attr:`_orm.Mapper.columns`.""" - @property + @util.non_memoized_property @util.deprecated("1.3", "Use .persist_selectable") def mapped_table(self): return self.persist_selectable @util.memoized_property - def _path_registry(self) -> PathRegistry: + def _path_registry(self) -> CachingEntityRegistry: return PathRegistry.per_mapper(self) def _configure_inheritance(self): @@ -994,8 +1099,6 @@ class Mapper( self._inheriting_mappers = util.WeakSequence() if self.inherits: - if isinstance(self.inherits, type): - self.inherits = class_mapper(self.inherits, configure=False) if not issubclass(self.class_, self.inherits.class_): raise sa_exc.ArgumentError( "Class '%s' does not inherit from '%s'" @@ -1011,11 +1114,9 @@ class Mapper( "only allowed from a %s mapper" % (np, self.class_.__name__, np) ) - # inherit_condition is optional. - if self.local_table is None: - self.local_table = self.inherits.local_table + + if self.single: self.persist_selectable = self.inherits.persist_selectable - self.single = True elif self.local_table is not self.inherits.local_table: if self.concrete: self.persist_selectable = self.local_table @@ -1068,6 +1169,7 @@ class Mapper( self.local_table.description, ) ) from afe + assert self.inherits.persist_selectable is not None self.persist_selectable = sql.join( self.inherits.persist_selectable, self.local_table, @@ -1149,6 +1251,7 @@ class Mapper( else: self._all_tables = set() self.base_mapper = self + assert self.local_table is not None self.persist_selectable = self.local_table if self.polymorphic_identity is not None: self.polymorphic_map[self.polymorphic_identity] = self @@ -1160,21 +1263,34 @@ class Mapper( % self ) - def _set_with_polymorphic(self, with_polymorphic): + def _set_with_polymorphic( + self, with_polymorphic: Optional[_WithPolymorphicArg] + ) -> None: if with_polymorphic == "*": self.with_polymorphic = ("*", None) elif isinstance(with_polymorphic, (tuple, list)): if isinstance(with_polymorphic[0], (str, tuple, list)): - self.with_polymorphic = with_polymorphic + self.with_polymorphic = cast( + """Tuple[ + Union[ + Literal["*"], + Sequence[Union["Mapper[Any]", Type[Any]]], + ], + Optional["FromClause"], + ]""", + with_polymorphic, + ) else: self.with_polymorphic = (with_polymorphic, None) elif with_polymorphic is not None: - raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") + raise sa_exc.ArgumentError( + f"Invalid setting for with_polymorphic: {with_polymorphic!r}" + ) else: self.with_polymorphic = None if self.with_polymorphic and self.with_polymorphic[1] is not None: - self.with_polymorphic = ( + self.with_polymorphic = ( # type: ignore self.with_polymorphic[0], coercions.expect( roles.StrictFromClauseRole, @@ -1191,6 +1307,7 @@ class Mapper( if self.with_polymorphic is None: self._set_with_polymorphic((subcl,)) elif self.with_polymorphic[0] != "*": + assert isinstance(self.with_polymorphic[0], tuple) self._set_with_polymorphic( (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) ) @@ -1241,7 +1358,7 @@ class Mapper( # we expect that declarative has applied the class manager # already and set up a registry. if this is None, # this raises as of 2.0. - manager = attributes.manager_of_class(self.class_) + manager = attributes.opt_manager_of_class(self.class_) if self.non_primary: if not manager or not manager.is_mapped: @@ -1251,6 +1368,8 @@ class Mapper( "Mapper." % self.class_ ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry self._identity_class = manager.mapper._identity_class manager.registry._add_non_primary_mapper(self) @@ -1275,7 +1394,7 @@ class Mapper( manager = instrumentation.register_class( self.class_, mapper=self, - expired_attribute_loader=util.partial( + expired_attribute_loader=util.partial( # type: ignore loading.load_scalar_attributes, self ), # finalize flag means instrument the __init__ method @@ -1284,6 +1403,8 @@ class Mapper( ) self.class_manager = manager + + assert manager.registry is not None self.registry = manager.registry # The remaining members can be added by any mapper, @@ -1315,15 +1436,25 @@ class Mapper( {name: (method, validation_opts)} ) - def _set_dispose_flags(self): + def _set_dispose_flags(self) -> None: self.configured = True self._ready_for_configure = True self._dispose_called = True self.__dict__.pop("_configure_failed", None) - def _configure_pks(self): - self.tables = sql_util.find_tables(self.persist_selectable) + def _configure_pks(self) -> None: + self.tables = cast( + "List[Table]", sql_util.find_tables(self.persist_selectable) + ) + for t in self.tables: + if not isinstance(t, Table): + raise sa_exc.ArgumentError( + f"ORM mappings can only be made against schema-level " + f"Table objects, not TableClause; got " + f"tableclause {t.name !r}" + ) + self._all_tables.update(t for t in self.tables if isinstance(t, Table)) self._pks_by_table = {} self._cols_by_table = {} @@ -1335,16 +1466,16 @@ class Mapper( pk_cols = util.column_set(c for c in all_cols if c.primary_key) # identify primary key columns which are also mapped by this mapper. - tables = set(self.tables + [self.persist_selectable]) - self._all_tables.update(tables) - for t in tables: - if t.primary_key and pk_cols.issuperset(t.primary_key): + for fc in set(self.tables).union([self.persist_selectable]): + if fc.primary_key and pk_cols.issuperset(fc.primary_key): # ordering is important since it determines the ordering of # mapper.primary_key (and therefore query.get()) - self._pks_by_table[t] = util.ordered_column_set( - t.primary_key - ).intersection(pk_cols) - self._cols_by_table[t] = util.ordered_column_set(t.c).intersection( + self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501 + fc.primary_key + ).intersection( + pk_cols + ) + self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501 all_cols ) @@ -1386,10 +1517,15 @@ class Mapper( self.primary_key = self.inherits.primary_key else: # determine primary key from argument or persist_selectable pks + primary_key: Collection[ColumnElement[Any]] + if self._primary_key_argument: primary_key = [ - self.persist_selectable.corresponding_column(c) - for c in self._primary_key_argument + cc if cc is not None else c + for cc, c in ( + (self.persist_selectable.corresponding_column(c), c) + for c in self._primary_key_argument + ) ] else: # if heuristically determined PKs, reduce to the minimal set @@ -1413,7 +1549,7 @@ class Mapper( # determine cols that aren't expressed within our tables; mark these # as "read only" properties which are refreshed upon INSERT/UPDATE - self._readonly_props = set( + self._readonly_props = { self._columntoproperty[col] for col in self._columntoproperty if self._columntoproperty[col] not in self._identity_key_props @@ -1421,12 +1557,12 @@ class Mapper( not hasattr(col, "table") or col.table not in self._cols_by_table ) - ) + } - def _configure_properties(self): + def _configure_properties(self) -> None: # TODO: consider using DedupeColumnCollection - self.columns = self.c = sql_base.ColumnCollection() + self.columns = self.c = sql_base.ColumnCollection() # type: ignore # object attribute names mapped to MapperProperty objects self._props = util.OrderedDict() @@ -1454,7 +1590,6 @@ class Mapper( continue column_key = (self.column_prefix or "") + column.key - if self._should_exclude( column.key, column_key, @@ -1542,6 +1677,7 @@ class Mapper( col = self.polymorphic_on if isinstance(col, schema.Column) and ( self.with_polymorphic is None + or self.with_polymorphic[1] is None or self.with_polymorphic[1].corresponding_column(col) is None ): @@ -1763,8 +1899,8 @@ class Mapper( self.columns.add(col, key) for col in prop.columns + prop._orig_columns: - for col in col.proxy_set: - self._columntoproperty[col] = prop + for proxy_col in col.proxy_set: + self._columntoproperty[proxy_col] = prop prop.key = key @@ -2033,7 +2169,9 @@ class Mapper( self._check_configure() return iter(self._props.values()) - def _mappers_from_spec(self, spec, selectable): + def _mappers_from_spec( + self, spec: Any, selectable: Optional[FromClause] + ) -> Sequence[Mapper[Any]]: """given a with_polymorphic() argument, return the set of mappers it represents. @@ -2044,7 +2182,7 @@ class Mapper( if spec == "*": mappers = list(self.self_and_descendants) elif spec: - mappers = set() + mapper_set = set() for m in util.to_list(spec): m = _class_to_mapper(m) if not m.isa(self): @@ -2053,10 +2191,10 @@ class Mapper( ) if selectable is None: - mappers.update(m.iterate_to_root()) + mapper_set.update(m.iterate_to_root()) else: - mappers.add(m) - mappers = [m for m in self.self_and_descendants if m in mappers] + mapper_set.add(m) + mappers = [m for m in self.self_and_descendants if m in mapper_set] else: mappers = [] @@ -2067,7 +2205,9 @@ class Mapper( mappers = [m for m in mappers if m.local_table in tables] return mappers - def _selectable_from_mappers(self, mappers, innerjoin): + def _selectable_from_mappers( + self, mappers: Iterable[Mapper[Any]], innerjoin: bool + ) -> FromClause: """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), construct an outerjoin amongst those mapper's mapped tables. @@ -2098,13 +2238,13 @@ class Mapper( def _single_table_criterion(self): if self.single and self.inherits and self.polymorphic_on is not None: return self.polymorphic_on._annotate({"parentmapper": self}).in_( - m.polymorphic_identity for m in self.self_and_descendants + [m.polymorphic_identity for m in self.self_and_descendants] ) else: return None @HasMemoized.memoized_attribute - def _with_polymorphic_mappers(self): + def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]: self._check_configure() if not self.with_polymorphic: @@ -2124,8 +2264,8 @@ class Mapper( """ self._check_configure() - @HasMemoized.memoized_attribute - def _with_polymorphic_selectable(self): + @HasMemoized_ro_memoized_attribute + def _with_polymorphic_selectable(self) -> FromClause: if not self.with_polymorphic: return self.persist_selectable @@ -2143,7 +2283,7 @@ class Mapper( """ - @HasMemoized.memoized_attribute + @HasMemoized_ro_memoized_attribute def _insert_cols_evaluating_none(self): return dict( ( @@ -2250,7 +2390,7 @@ class Mapper( @HasMemoized.memoized_instancemethod def __clause_element__(self): - annotations = { + annotations: Dict[str, Any] = { "entity_namespace": self, "parententity": self, "parentmapper": self, @@ -2290,7 +2430,7 @@ class Mapper( ) @property - def selectable(self): + def selectable(self) -> FromClause: """The :class:`_schema.FromClause` construct this :class:`_orm.Mapper` selects from by default. @@ -2302,8 +2442,11 @@ class Mapper( return self._with_polymorphic_selectable def _with_polymorphic_args( - self, spec=None, selectable=False, innerjoin=False - ): + self, + spec: Any = None, + selectable: Union[Literal[False, None], FromClause] = False, + innerjoin: bool = False, + ) -> Tuple[Sequence[Mapper[Any]], FromClause]: if selectable not in (None, False): selectable = coercions.expect( roles.StrictFromClauseRole, selectable, allow_select=True @@ -2357,7 +2500,7 @@ class Mapper( ] @HasMemoized.memoized_attribute - def _polymorphic_adapter(self): + def _polymorphic_adapter(self) -> Optional[sql_util.ColumnAdapter]: if self.with_polymorphic: return sql_util.ColumnAdapter( self.selectable, equivalents=self._equivalent_columns @@ -2394,7 +2537,7 @@ class Mapper( yield c @HasMemoized.memoized_attribute - def attrs(self) -> util.ReadOnlyProperties["MapperProperty"]: + def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]: """A namespace of all :class:`.MapperProperty` objects associated this mapper. @@ -2432,7 +2575,7 @@ class Mapper( return util.ReadOnlyProperties(self._props) @HasMemoized.memoized_attribute - def all_orm_descriptors(self): + def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]: """A namespace of all :class:`.InspectionAttr` attributes associated with the mapped class. @@ -2503,7 +2646,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def synonyms(self): + def synonyms(self) -> util.ReadOnlyProperties[Synonym[Any]]: """Return a namespace of all :class:`.Synonym` properties maintained by this :class:`_orm.Mapper`. @@ -2523,7 +2666,7 @@ class Mapper( return self.class_ @HasMemoized.memoized_attribute - def column_attrs(self): + def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: """Return a namespace of all :class:`.ColumnProperty` properties maintained by this :class:`_orm.Mapper`. @@ -2536,9 +2679,9 @@ class Mapper( """ return self._filter_properties(properties.ColumnProperty) - @util.preload_module("sqlalchemy.orm.relationships") @HasMemoized.memoized_attribute - def relationships(self): + @util.preload_module("sqlalchemy.orm.relationships") + def relationships(self) -> util.ReadOnlyProperties[Relationship[Any]]: """A namespace of all :class:`.Relationship` properties maintained by this :class:`_orm.Mapper`. @@ -2567,7 +2710,7 @@ class Mapper( @HasMemoized.memoized_attribute @util.preload_module("sqlalchemy.orm.descriptor_props") - def composites(self): + def composites(self) -> util.ReadOnlyProperties[Composite[Any]]: """Return a namespace of all :class:`.Composite` properties maintained by this :class:`_orm.Mapper`. @@ -2582,7 +2725,9 @@ class Mapper( util.preloaded.orm_descriptor_props.Composite ) - def _filter_properties(self, type_): + def _filter_properties( + self, type_: Type[_MP] + ) -> util.ReadOnlyProperties[_MP]: self._check_configure() return util.ReadOnlyProperties( util.OrderedDict( @@ -2610,7 +2755,7 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _equivalent_columns(self): + def _equivalent_columns(self) -> _EquivalentColumnMap: """Create a map of all equivalent columns, based on the determination of column pairs that are equated to one another based on inherit condition. This is designed @@ -2630,18 +2775,18 @@ class Mapper( } """ - result = util.column_dict() + result: _EquivalentColumnMap = {} def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: - result[binary.left] = util.column_set((binary.right,)) + result[binary.left] = {binary.right} if binary.right in result: result[binary.right].add(binary.left) else: - result[binary.right] = util.column_set((binary.left,)) + result[binary.right] = {binary.left} for mapper in self.base_mapper.self_and_descendants: if mapper.inherit_condition is not None: @@ -2711,13 +2856,13 @@ class Mapper( return False - def common_parent(self, other): + def common_parent(self, other: Mapper[Any]) -> bool: """Return true if the given mapper shares a common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper - def is_sibling(self, other): + def is_sibling(self, other: Mapper[Any]) -> bool: """return true if the other mapper is an inheriting sibling to this one. common parent but different branch @@ -2728,7 +2873,9 @@ class Mapper( and not other.isa(self) ) - def _canload(self, state, allow_subtypes): + def _canload( + self, state: InstanceState[Any], allow_subtypes: bool + ) -> bool: s = self.primary_mapper() if self.polymorphic_on is not None or allow_subtypes: return _state_mapper(state).isa(s) @@ -2738,19 +2885,19 @@ class Mapper( def isa(self, other: Mapper[Any]) -> bool: """Return True if the this mapper inherits from the given mapper.""" - m = self + m: Optional[Mapper[Any]] = self while m and m is not other: m = m.inherits return bool(m) - def iterate_to_root(self): - m = self + def iterate_to_root(self) -> Iterator[Mapper[Any]]: + m: Optional[Mapper[Any]] = self while m: yield m m = m.inherits @HasMemoized.memoized_attribute - def self_and_descendants(self): + def self_and_descendants(self) -> Sequence[Mapper[Any]]: """The collection including this mapper and all descendant mappers. This includes not just the immediately inheriting mappers but @@ -2765,7 +2912,7 @@ class Mapper( stack.extend(item._inheriting_mappers) return util.WeakSequence(descendants) - def polymorphic_iterator(self): + def polymorphic_iterator(self) -> Iterator[Mapper[Any]]: """Iterate through the collection including this mapper and all descendant mappers. @@ -2778,18 +2925,18 @@ class Mapper( """ return iter(self.self_and_descendants) - def primary_mapper(self): + def primary_mapper(self) -> Mapper[Any]: """Return the primary mapper corresponding to this mapper's class key (class).""" return self.class_manager.mapper @property - def primary_base_mapper(self): + def primary_base_mapper(self) -> Mapper[Any]: return self.class_manager.mapper.base_mapper def _result_has_identity_key(self, result, adapter=None): - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] rk = result.keys() @@ -2799,25 +2946,35 @@ class Mapper( else: return True - def identity_key_from_row(self, row, identity_token=None, adapter=None): + def identity_key_from_row( + self, + row: Optional[Union[Row, RowMapping]], + identity_token: Optional[Any] = None, + adapter: Optional[ColumnAdapter] = None, + ) -> _IdentityKeyType[_O]: """Return an identity-map key for use in storing/retrieving an item from the identity map. - :param row: A :class:`.Row` instance. The columns which are - mapped by this :class:`_orm.Mapper` should be locatable in the row, - preferably via the :class:`_schema.Column` - object directly (as is the case - when a :func:`_expression.select` construct is executed), or - via string names of the form ``<tablename>_<colname>``. + :param row: A :class:`.Row` or :class:`.RowMapping` produced from a + result set that selected from the ORM mapped primary key columns. + + .. versionchanged:: 2.0 + :class:`.Row` or :class:`.RowMapping` are accepted + for the "row" argument """ - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key if adapter: pk_cols = [adapter.columns[c] for c in pk_cols] + if hasattr(row, "_mapping"): + mapping = row._mapping # type: ignore + else: + mapping = cast("Mapping[Any, Any]", row) + return ( self._identity_class, - tuple(row[column] for column in pk_cols), + tuple(mapping[column] for column in pk_cols), # type: ignore identity_token, ) @@ -2852,12 +3009,12 @@ class Mapper( """ state = attributes.instance_state(instance) - return self._identity_key_from_state(state, attributes.PASSIVE_OFF) + return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF) def _identity_key_from_state( self, state: InstanceState[_O], - passive: PassiveFlag = attributes.PASSIVE_RETURN_NO_VALUE, + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, ) -> _IdentityKeyType[_O]: dict_ = state.dict manager = state.manager @@ -2884,7 +3041,7 @@ class Mapper( """ state = attributes.instance_state(instance) identity_key = self._identity_key_from_state( - state, attributes.PASSIVE_OFF + state, PassiveFlag.PASSIVE_OFF ) return identity_key[1] @@ -2913,14 +3070,14 @@ class Mapper( @HasMemoized.memoized_attribute def _all_pk_cols(self): - collection = set() + collection: Set[ColumnClause[Any]] = set() for table in self.tables: collection.update(self._pks_by_table[table]) return collection @HasMemoized.memoized_attribute def _should_undefer_in_wildcard(self): - cols = set(self.primary_key) + cols: Set[ColumnElement[Any]] = set(self.primary_key) if self.polymorphic_on is not None: cols.add(self.polymorphic_on) return cols @@ -2951,11 +3108,11 @@ class Mapper( state = attributes.instance_state(obj) dict_ = attributes.instance_dict(obj) return self._get_committed_state_attr_by_column( - state, dict_, column, passive=attributes.PASSIVE_OFF + state, dict_, column, passive=PassiveFlag.PASSIVE_OFF ) def _get_committed_state_attr_by_column( - self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE + self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE ): prop = self._columntoproperty[column] @@ -2978,7 +3135,7 @@ class Mapper( col_attribute_names = set(attribute_names).intersection( state.mapper.column_attrs.keys() ) - tables = set( + tables: Set[FromClause] = set( chain( *[ sql_util.find_tables(c, check_columns=True) @@ -3002,7 +3159,7 @@ class Mapper( state, state.dict, leftcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if leftval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3014,7 +3171,7 @@ class Mapper( state, state.dict, rightcol, - passive=attributes.PASSIVE_NO_INITIALIZE, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, ) if rightval in orm_util._none_set: raise _OptGetColumnsNotAvailable() @@ -3022,7 +3179,7 @@ class Mapper( None, rightval, type_=binary.right.type ) - allconds = [] + allconds: List[ColumnElement[bool]] = [] start = False @@ -3035,6 +3192,9 @@ class Mapper( elif not isinstance(mapper.local_table, expression.TableClause): return None if start and not mapper.single: + assert mapper.inherits + assert not mapper.concrete + assert mapper.inherit_condition is not None allconds.append(mapper.inherit_condition) tables.add(mapper.local_table) @@ -3043,11 +3203,13 @@ class Mapper( # descendant-most class should all be present and joined to each # other. try: - allconds[0] = visitors.cloned_traverse( + _traversed = visitors.cloned_traverse( allconds[0], {}, {"binary": visit_binary} ) except _OptGetColumnsNotAvailable: return None + else: + allconds[0] = _traversed cond = sql.and_(*allconds) @@ -3145,6 +3307,8 @@ class Mapper( for pk in self.primary_key ] + in_expr: ColumnElement[Any] + if len(primary_key) > 1: in_expr = sql.tuple_(*primary_key) else: @@ -3209,11 +3373,22 @@ class Mapper( traverse all objects without relying on cascades. """ - visited_states = set() + visited_states: Set[InstanceState[Any]] = set() prp, mpp = object(), object() assert state.mapper.isa(self) + # this is actually a recursive structure, fully typing it seems + # a little too difficult for what it's worth here + visitables: Deque[ + Tuple[ + Deque[Any], + object, + Optional[InstanceState[Any]], + Optional[_InstanceDict], + ] + ] + visitables = deque( [(deque(state.mapper._props.values()), prp, state, state.dict)] ) @@ -3226,8 +3401,10 @@ class Mapper( if item_type is prp: prop = iterator.popleft() - if type_ not in prop.cascade: + if not prop.cascade or type_ not in prop.cascade: continue + assert parent_state is not None + assert parent_dict is not None queue = deque( prop.cascade_iterator( type_, @@ -3267,7 +3444,7 @@ class Mapper( @HasMemoized.memoized_attribute def _sorted_tables(self): - table_to_mapper = {} + table_to_mapper: Dict[Table, Mapper[Any]] = {} for mapper in self.base_mapper.self_and_descendants: for t in mapper.tables: @@ -3316,9 +3493,9 @@ class Mapper( ret[t] = table_to_mapper[t] return ret - def _memo(self, key, callable_): + def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T: if key in self._memoized_values: - return self._memoized_values[key] + return cast(_T, self._memoized_values[key]) else: self._memoized_values[key] = value = callable_() return value @@ -3328,14 +3505,22 @@ class Mapper( """memoized map of tables to collections of columns to be synchronized upwards to the base mapper.""" - result = util.defaultdict(list) + result: util.defaultdict[ + Table, + List[ + Tuple[ + Mapper[Any], + List[Tuple[ColumnElement[Any], ColumnElement[Any]]], + ] + ], + ] = util.defaultdict(list) for table in self._sorted_tables: cols = set(table.c) for m in self.iterate_to_root(): if m._inherits_equated_pairs and cols.intersection( reduce( - set.union, + set.union, # type: ignore [l.proxy_set for l, r in m._inherits_equated_pairs], ) ): @@ -3440,7 +3625,7 @@ def _configure_registries(registries, cascade): else: return - Mapper.dispatch._for_class(Mapper).before_configured() + Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 # initialize properties on all mappers # note that _mapper_registry is unordered, which # may randomly conceal/reveal issues related to @@ -3449,7 +3634,7 @@ def _configure_registries(registries, cascade): _do_configure_registries(registries, cascade) finally: _already_compiling = False - Mapper.dispatch._for_class(Mapper).after_configured() + Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore @util.preload_module("sqlalchemy.orm.decl_api") @@ -3480,7 +3665,7 @@ def _do_configure_registries(registries, cascade): "Original exception was: %s" % (mapper, mapper._configure_failed) ) - e._configure_failed = mapper._configure_failed + e._configure_failed = mapper._configure_failed # type: ignore raise e if not mapper.configured: @@ -3636,7 +3821,7 @@ def _event_on_init(state, args, kwargs): instrumenting_mapper._set_polymorphic_identity(state) -class _ColumnMapping(dict): +class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]): """Error reporting helper for mapper._columntoproperty.""" __slots__ = ("mapper",) diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index e2cf1d5b0..361cea975 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -13,22 +13,70 @@ from __future__ import annotations from functools import reduce from itertools import chain import logging +import operator from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple +from typing import TYPE_CHECKING from typing import Union from . import base as orm_base +from ._typing import insp_is_mapper_property from .. import exc -from .. import inspection from .. import util from ..sql import visitors from ..sql.cache_key import HasCacheKey +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from .interfaces import MapperProperty + from .mapper import Mapper + from .relationships import Relationship + from .util import AliasedInsp + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.elements import BindParameter + from ..sql.visitors import anon_map + from ..util.typing import TypeGuard + + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: + ... + + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: + ... + +else: + is_root = operator.attrgetter("is_root") + is_entity = operator.attrgetter("is_entity") + + +_SerializedPath = List[Any] + +_PathElementType = Union[ + str, "_InternalEntityType[Any]", "MapperProperty[Any]" +] + +# the representation is in fact +# a tuple with alternating: +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# this might someday be a tuple of 2-tuples instead, but paths can be +# chopped at odd intervals as well so this is less flexible +_PathRepresentation = Tuple[_PathElementType, ...] + +_OddPathRepresentation = Sequence["_InternalEntityType[Any]"] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] + + log = logging.getLogger(__name__) -def _unreduce_path(path): +def _unreduce_path(path: _SerializedPath) -> PathRegistry: return PathRegistry.deserialize(path) @@ -67,17 +115,18 @@ class PathRegistry(HasCacheKey): is_token = False is_root = False has_entity = False + is_entity = False - path: Tuple - natural_path: Tuple - parent: Union["PathRegistry", None] + path: _PathRepresentation + natural_path: _PathRepresentation + parent: Optional[PathRegistry] + root: RootRegistry - root: "PathRegistry" - _cache_key_traversal = [ + _cache_key_traversal: _CacheKeyTraversalType = [ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) ] - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: try: return other is not None and self.path == other._path_for_compare except AttributeError: @@ -87,7 +136,7 @@ class PathRegistry(HasCacheKey): ) return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: try: return other is None or self.path != other._path_for_compare except AttributeError: @@ -98,74 +147,88 @@ class PathRegistry(HasCacheKey): return True @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return self.path - def set(self, attributes, key, value): + def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None: log.debug("set '%s' on path '%s' to '%s'", key, self, value) attributes[(key, self.natural_path)] = value - def setdefault(self, attributes, key, value): + def setdefault( + self, attributes: Dict[Any, Any], key: Any, value: Any + ) -> None: log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) attributes.setdefault((key, self.natural_path), value) - def get(self, attributes, key, value=None): + def get( + self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None + ) -> Any: key = (key, self.natural_path) if key in attributes: return attributes[key] else: return value - def __len__(self): + def __len__(self) -> int: return len(self.path) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __getitem__(self, key: Any) -> "PathRegistry": + def __getitem__(self, key: Any) -> PathRegistry: raise NotImplementedError() + # TODO: what are we using this for? @property - def length(self): + def length(self) -> int: return len(self.path) - def pairs(self): - path = self.path - for i in range(0, len(path), 2): - yield path[i], path[i + 1] - - def contains_mapper(self, mapper): - for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]: + def pairs( + self, + ) -> Iterator[ + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + ]: + odd_path = cast(_OddPathRepresentation, self.path) + even_path = cast(_EvenPathRepresentation, odd_path) + for i in range(0, len(odd_path), 2): + yield odd_path[i], even_path[i + 1] + + def contains_mapper(self, mapper: Mapper[Any]) -> bool: + _m_path = cast(_OddPathRepresentation, self.path) + for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]: if path_mapper.is_mapper and path_mapper.isa(mapper): return True else: return False - def contains(self, attributes, key): + def contains(self, attributes: Dict[Any, Any], key: Any) -> bool: return (key, self.path) in attributes - def __reduce__(self): + def __reduce__(self) -> Any: return _unreduce_path, (self.serialize(),) @classmethod - def _serialize_path(cls, path): + def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath: + _m_path = cast(_OddPathRepresentation, path) + _p_path = cast(_EvenPathRepresentation, path) + return list( zip( - [ + tuple( m.class_ if (m.is_mapper or m.is_aliased_class) else str(m) - for m in [path[i] for i in range(0, len(path), 2)] - ], - [ - path[i].key if (path[i].is_property) else str(path[i]) - for i in range(1, len(path), 2) - ] - + [None], + for m in [_m_path[i] for i in range(0, len(_m_path), 2)] + ), + tuple( + p.key if insp_is_mapper_property(p) else str(p) + for p in [_p_path[i] for i in range(1, len(_p_path), 2)] + ) + + (None,), ) ) @classmethod - def _deserialize_path(cls, path): - def _deserialize_mapper_token(mcls): + def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation: + def _deserialize_mapper_token(mcls: Any) -> Any: return ( # note: we likely dont want configure=True here however # this is maintained at the moment for backwards compatibility @@ -174,15 +237,15 @@ class PathRegistry(HasCacheKey): else PathToken._intern[mcls] ) - def _deserialize_key_token(mcls, key): + def _deserialize_key_token(mcls: Any, key: Any) -> Any: if key is None: return None elif key in PathToken._intern: return PathToken._intern[key] else: - return orm_base._inspect_mapped_class( - mcls, configure=True - ).attrs[key] + mp = orm_base._inspect_mapped_class(mcls, configure=True) + assert mp is not None + return mp.attrs[key] p = tuple( chain( @@ -199,28 +262,63 @@ class PathRegistry(HasCacheKey): p = p[0:-1] return p - def serialize(self) -> Sequence[Any]: + def serialize(self) -> _SerializedPath: path = self.path return self._serialize_path(path) @classmethod - def deserialize(cls, path: Sequence[Any]) -> PathRegistry: + def deserialize(cls, path: _SerializedPath) -> PathRegistry: assert path is not None p = cls._deserialize_path(path) return cls.coerce(p) + @overload @classmethod - def per_mapper(cls, mapper): + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: + ... + + @overload + @classmethod + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: + ... + + @classmethod + def per_mapper( + cls, mapper: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: if mapper.is_mapper: return CachingEntityRegistry(cls.root, mapper) else: return SlotsEntityRegistry(cls.root, mapper) @classmethod - def coerce(cls, raw): - return reduce(lambda prev, next: prev[next], raw, cls.root) + def coerce(cls, raw: _PathRepresentation) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + # can't quite get mypy to appreciate this one :) + return reduce(_red, raw, cls.root) # type: ignore + + def __add__(self, other: PathRegistry) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] - def token(self, token): + return reduce(_red, other.path, self) + + def __str__(self) -> str: + return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r})" + + +class CreatesToken(PathRegistry): + __slots__ = () + + is_aliased_class: bool + is_root: bool + + def token(self, token: str) -> TokenRegistry: if token.endswith(f":{_WILDCARD_TOKEN}"): return TokenRegistry(self, token) elif token.endswith(f":{_DEFAULT_TOKEN}"): @@ -228,34 +326,47 @@ class PathRegistry(HasCacheKey): else: raise exc.ArgumentError(f"invalid token: {token}") - def __add__(self, other): - return reduce(lambda prev, next: prev[next], other.path, self) - - def __str__(self): - return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" - - def __repr__(self): - return f"{self.__class__.__name__}({self.path!r})" - -class RootRegistry(PathRegistry): +class RootRegistry(CreatesToken): """Root registry, defers to mappers so that paths are maintained per-root-mapper. """ + __slots__ = () + inherit_cache = True path = natural_path = () has_entity = False is_aliased_class = False is_root = True + is_unnatural = False + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... - def __getitem__(self, entity): + def __getitem__( + self, entity: Union[str, _InternalEntityType[Any]] + ) -> Union[TokenRegistry, AbstractEntityRegistry]: if entity in PathToken._intern: + if TYPE_CHECKING: + assert isinstance(entity, str) return TokenRegistry(self, PathToken._intern[entity]) else: - return inspection.inspect(entity)._path_registry + try: + return entity._path_registry # type: ignore + except AttributeError: + raise IndexError( + f"invalid argument for RootRegistry.__getitem__: {entity}" + ) PathRegistry.root = RootRegistry() @@ -264,17 +375,19 @@ PathRegistry.root = RootRegistry() class PathToken(orm_base.InspectionAttr, HasCacheKey, str): """cacheable string token""" - _intern = {} + _intern: Dict[str, PathToken] = {} - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (str(self),) @property - def _path_for_compare(self): + def _path_for_compare(self) -> Optional[_PathRepresentation]: return None @classmethod - def intern(cls, strvalue): + def intern(cls, strvalue: str) -> PathToken: if strvalue in cls._intern: return cls._intern[strvalue] else: @@ -287,7 +400,10 @@ class TokenRegistry(PathRegistry): inherit_cache = True - def __init__(self, parent, token): + token: str + parent: CreatesToken + + def __init__(self, parent: CreatesToken, token: str): token = PathToken.intern(token) self.token = token @@ -299,21 +415,33 @@ class TokenRegistry(PathRegistry): is_token = True - def generate_for_superclasses(self): - if not self.parent.is_aliased_class and not self.parent.is_root: - for ent in self.parent.mapper.iterate_to_root(): - yield TokenRegistry(self.parent.parent[ent], self.token) + def generate_for_superclasses(self) -> Iterator[PathRegistry]: + parent = self.parent + if is_root(parent): + yield self + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + if not parent.is_aliased_class: + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token) elif ( - self.parent.is_aliased_class - and self.parent.entity._is_with_polymorphic + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + )._is_with_polymorphic ): yield self - for ent in self.parent.entity._with_polymorphic_entities: - yield TokenRegistry(self.parent.parent[ent], self.token) + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield TokenRegistry(parent.parent[ent], self.token) else: yield self - def __getitem__(self, entity): + def __getitem__(self, entity: Any) -> Any: try: return self.path[entity] except TypeError as err: @@ -321,23 +449,42 @@ class TokenRegistry(PathRegistry): class PropRegistry(PathRegistry): - is_unnatural = False + __slots__ = ( + "prop", + "parent", + "path", + "natural_path", + "has_entity", + "entity", + "mapper", + "_wildcard_path_loader_key", + "_default_path_loader_key", + "_loader_key", + "is_unnatural", + ) inherit_cache = True - def __init__(self, parent, prop): + prop: MapperProperty[Any] + mapper: Optional[Mapper[Any]] + entity: Optional[_InternalEntityType[Any]] + + def __init__( + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + ): # restate this path in terms of the # given MapperProperty's parent. - insp = inspection.inspect(parent[-1]) - natural_parent = parent + insp = cast("_InternalEntityType[Any]", parent[-1]) + natural_parent: AbstractEntityRegistry = parent + self.is_unnatural = False - if not insp.is_aliased_class or insp._use_mapper_path: + if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore parent = natural_parent = parent.parent[prop.parent] elif ( insp.is_aliased_class and insp.with_polymorphic_mappers and prop.parent in insp.with_polymorphic_mappers ): - subclass_entity = parent[-1]._entity_for_mapper(prop.parent) + subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501 parent = parent.parent[subclass_entity] # when building a path where with_polymorphic() is in use, @@ -388,43 +535,74 @@ class PropRegistry(PathRegistry): self.parent = parent self.path = parent.path + (prop,) self.natural_path = natural_parent.natural_path + (prop,) + self.has_entity = prop._links_to_entity + if prop._is_relationship: + if TYPE_CHECKING: + assert isinstance(prop, Relationship) + self.entity = prop.entity + self.mapper = prop.mapper + else: + self.entity = None + self.mapper = None self._wildcard_path_loader_key = ( "loader", - parent.path + self.prop._wildcard_token, + parent.path + self.prop._wildcard_token, # type: ignore ) self._default_path_loader_key = self.prop._default_path_loader_key self._loader_key = ("loader", self.natural_path) - @util.memoized_property - def has_entity(self): - return self.prop._links_to_entity + @property + def entity_path(self) -> AbstractEntityRegistry: + assert self.entity is not None + return self[self.entity] - @util.memoized_property - def entity(self): - return self.prop.entity + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... - @property - def mapper(self): - return self.prop.mapper + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - @property - def entity_path(self): - return self[self.entity] + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... - def __getitem__(self, entity): + def __getitem__( + self, entity: Union[int, slice, _InternalEntityType[Any]] + ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: if isinstance(entity, (int, slice)): return self.path[entity] else: return SlotsEntityRegistry(self, entity) -class AbstractEntityRegistry(PathRegistry): - __slots__ = () +class AbstractEntityRegistry(CreatesToken): + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "path", + "entity", + "natural_path", + ) has_entity = True - - def __init__(self, parent, entity): + is_entity = True + + parent: Union[RootRegistry, PropRegistry] + key: _InternalEntityType[Any] + entity: _InternalEntityType[Any] + is_aliased_class: bool + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): self.key = entity self.parent = parent self.is_aliased_class = entity.is_aliased_class @@ -447,11 +625,11 @@ class AbstractEntityRegistry(PathRegistry): if parent.path and (self.is_aliased_class or parent.is_unnatural): # this is an infrequent code path used only for loader strategies # that also make use of of_type(). - if entity.mapper.isa(parent.natural_path[-1].entity): + if entity.mapper.isa(parent.natural_path[-1].entity): # type: ignore # noqa: E501 self.natural_path = parent.natural_path + (entity.mapper,) else: self.natural_path = parent.natural_path + ( - parent.natural_path[-1].entity, + parent.natural_path[-1].entity, # type: ignore ) # it seems to make sense that since these paths get mixed up # with statements that are cached or not, we should make @@ -465,19 +643,35 @@ class AbstractEntityRegistry(PathRegistry): self.natural_path = self.path @property - def entity_path(self): + def entity_path(self) -> PathRegistry: return self @property - def mapper(self): - return inspection.inspect(self.entity).mapper + def mapper(self) -> Mapper[Any]: + return self.entity.mapper - def __bool__(self): + def __bool__(self) -> bool: return True - __nonzero__ = __bool__ + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... - def __getitem__(self, entity): + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + def __getitem__( + self, entity: Any + ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: if isinstance(entity, (int, slice)): return self.path[entity] elif entity in PathToken._intern: @@ -491,31 +685,40 @@ class SlotsEntityRegistry(AbstractEntityRegistry): # version inherit_cache = True - __slots__ = ( - "key", - "parent", - "is_aliased_class", - "entity", - "path", - "natural_path", - ) + +class _ERDict(Dict[Any, Any]): + def __init__(self, registry: CachingEntityRegistry): + self.registry = registry + + def __missing__(self, key: Any) -> PropRegistry: + self[key] = item = PropRegistry(self.registry, key) + + return item -class CachingEntityRegistry(AbstractEntityRegistry, dict): +class CachingEntityRegistry(AbstractEntityRegistry): # for long lived mapper, return dict based caching # version that creates reference cycles + __slots__ = ("_cache",) + inherit_cache = True - def __getitem__(self, entity): + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + super().__init__(parent, entity) + self._cache = _ERDict(self) + + def pop(self, key: Any, default: Any) -> Any: + return self._cache.pop(key, default) + + def __getitem__(self, entity: Any) -> Any: if isinstance(entity, (int, slice)): return self.path[entity] elif isinstance(entity, PathToken): return TokenRegistry(self, entity) else: - return dict.__getitem__(self, entity) - - def __missing__(self, key): - self[key] = item = PropRegistry(self, key) - - return item + return self._cache[entity] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c01825b6d..9f37e8457 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -19,6 +19,8 @@ from typing import cast from typing import List from typing import Optional from typing import Set +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from . import attributes @@ -38,17 +40,22 @@ from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey from .. import log -from .. import sql from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes from ..sql.schema import Column +from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import NoneType +if TYPE_CHECKING: + from ._typing import _ORMColumnExprArgument + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -78,6 +85,10 @@ class ColumnProperty( inherit_cache = True _links_to_entity = False + columns: List[ColumnElement[Any]] + + _is_polymorphic_discriminator: bool + __slots__ = ( "_orig_columns", "columns", @@ -99,7 +110,19 @@ class ColumnProperty( ) def __init__( - self, column: sql.ColumnElement[_T], *additional_columns, **kwargs + self, + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator]] = None, + descriptor: Optional[Any] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + _instrument: bool = True, ): super(ColumnProperty, self).__init__() columns = (column,) + additional_columns @@ -112,23 +135,24 @@ class ColumnProperty( ) for c in columns ] - self.parent = self.key = None - self.group = kwargs.pop("group", None) - self.deferred = kwargs.pop("deferred", False) - self.raiseload = kwargs.pop("raiseload", False) - self.instrument = kwargs.pop("_instrument", True) - self.comparator_factory = kwargs.pop( - "comparator_factory", self.__class__.Comparator + self.group = group + self.deferred = deferred + self.raiseload = raiseload + self.instrument = _instrument + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator ) - self.descriptor = kwargs.pop("descriptor", None) - self.active_history = kwargs.pop("active_history", False) - self.expire_on_flush = kwargs.pop("expire_on_flush", True) + self.descriptor = descriptor + self.active_history = active_history + self.expire_on_flush = expire_on_flush - if "info" in kwargs: - self.info = kwargs.pop("info") + if info is not None: + self.info = info - if "doc" in kwargs: - self.doc = kwargs.pop("doc") + if doc is not None: + self.doc = doc else: for col in reversed(self.columns): doc = getattr(col, "doc", None) @@ -138,12 +162,6 @@ class ColumnProperty( else: self.doc = None - if kwargs: - raise TypeError( - "%s received unexpected keyword argument(s): %s" - % (self.__class__.__name__, ", ".join(sorted(kwargs.keys()))) - ) - util.set_creation_order(self) self.strategy_key = ( @@ -445,7 +463,10 @@ class MappedColumn( self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys - self._has_nullable = "nullable" in kw + self._has_nullable = "nullable" in kw and kw.get("nullable") not in ( + None, + SchemaConst.NULL_UNSPECIFIED, + ) util.set_creation_order(self) def _copy(self, **kw): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a754bd4f2..395d01a1e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -30,6 +30,7 @@ from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import exc as orm_exc from . import interfaces @@ -77,6 +78,8 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL if TYPE_CHECKING: from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import Alias + from ..sql.selectable import Subquery __all__ = ["Query", "QueryContext"] @@ -2769,14 +2772,14 @@ class AliasOption(interfaces.LoaderOption): "for entities to be matched up to a query that is established " "via :meth:`.Query.from_statement` and now does nothing.", ) - def __init__(self, alias): + def __init__(self, alias: Union[Alias, Subquery]): r"""Return a :class:`.MapperOption` that will indicate to the :class:`_query.Query` that the main table has been aliased. """ - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState): pass diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 58c7c4efd..66021c9c2 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -21,7 +21,10 @@ import re import typing from typing import Any from typing import Callable +from typing import Dict from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Type from typing import TypeVar from typing import Union @@ -30,6 +33,7 @@ import weakref from . import attributes from . import strategy_options from .base import _is_mapped_class +from .base import class_mapper from .base import state_str from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY @@ -53,7 +57,9 @@ from ..sql import expression from ..sql import operators from ..sql import roles from ..sql import visitors -from ..sql.elements import SQLCoreOperations +from ..sql._typing import _ColumnExpressionArgument +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnClause from ..sql.util import _deep_deannotate from ..sql.util import _shallow_annotate from ..sql.util import adapt_criterion_to_null @@ -61,11 +67,14 @@ from ..sql.util import ClauseAdapter from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _EntityType from .mapper import Mapper from .util import AliasedClass from .util import AliasedInsp + from ..sql.elements import ColumnElement _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) @@ -81,6 +90,34 @@ _RelationshipArgumentType = Union[ Callable[[], "AliasedClass[_T]"], ] +_LazyLoadArgumentType = Literal[ + "select", + "joined", + "selectin", + "subquery", + "raise", + "raise_on_sql", + "noload", + "immediate", + "dynamic", + True, + False, + None, +] + + +_RelationshipJoinConditionArgument = Union[ + str, _ColumnExpressionArgument[bool] +] +_ORMOrderByArgument = Union[ + Literal[False], str, _ColumnExpressionArgument[Any] +] +_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] +_ORMColCollectionArgument = Union[ + str, + Sequence[Union[ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole]], +] + def remote(expr): """Annotate a portion of a primaryjoin expression @@ -144,6 +181,7 @@ class Relationship( inherit_cache = True _links_to_entity = True + _is_relationship = True _persistence_only = dict( passive_deletes=False, @@ -159,38 +197,39 @@ class Relationship( self, argument: Optional[_RelationshipArgumentType[_T]] = None, secondary=None, + *, + uselist=None, + collection_class=None, primaryjoin=None, secondaryjoin=None, - foreign_keys=None, - uselist=None, + back_populates=None, order_by=False, backref=None, - back_populates=None, + cascade_backrefs=False, overlaps=None, post_update=False, - cascade=False, + cascade="save-update, merge", viewonly=False, - lazy="select", - collection_class=None, - passive_deletes=_persistence_only["passive_deletes"], - passive_updates=_persistence_only["passive_updates"], + lazy: _LazyLoadArgumentType = "select", + passive_deletes=False, + passive_updates=True, + active_history=False, + enable_typechecks=True, + foreign_keys=None, remote_side=None, - enable_typechecks=_persistence_only["enable_typechecks"], join_depth=None, comparator_factory=None, single_parent=False, innerjoin=False, distinct_target_key=None, - doc=None, - active_history=_persistence_only["active_history"], - cascade_backrefs=_persistence_only["cascade_backrefs"], load_on_pending=False, - bake_queries=True, - _local_remote_pairs=None, query_class=None, info=None, omit_join=None, sync_backref=None, + doc=None, + bake_queries=True, + _local_remote_pairs=None, _legacy_inactive_history_style=False, ): super(Relationship, self).__init__() @@ -250,7 +289,6 @@ class Relationship( self.omit_join = omit_join self.local_remote_pairs = _local_remote_pairs - self.bake_queries = bake_queries self.load_on_pending = load_on_pending self.comparator_factory = comparator_factory or Relationship.Comparator self.comparator = self.comparator_factory(self, None) @@ -267,12 +305,7 @@ class Relationship( else: self._overlaps = () - if cascade is not False: - self.cascade = cascade - elif self.viewonly: - self.cascade = "none" - else: - self.cascade = "save-update, merge" + self.cascade = cascade self.order_by = order_by @@ -539,9 +572,9 @@ class Relationship( def _criterion_exists( self, - criterion: Optional[SQLCoreOperations[Any]] = None, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> Exists[bool]: + ) -> Exists: if getattr(self, "_of_type", None): info = inspect(self._of_type) target_mapper, to_selectable, is_aliased_class = ( @@ -898,7 +931,12 @@ class Relationship( comparator: Comparator[_T] - def _with_parent(self, instance, alias_secondary=True, from_entity=None): + def _with_parent( + self, + instance: object, + alias_secondary: bool = True, + from_entity: Optional[_EntityType[Any]] = None, + ) -> ColumnElement[bool]: assert instance is not None adapt_source = None if from_entity is not None: @@ -1502,7 +1540,7 @@ class Relationship( argument = argument if isinstance(argument, type): - entity = mapperlib.class_mapper(argument, configure=False) + entity = class_mapper(argument, configure=False) else: try: entity = inspect(argument) @@ -1568,7 +1606,7 @@ class Relationship( """Test that this relationship is legal, warn about inheritance conflicts.""" mapperlib = util.preloaded.orm_mapper - if self.parent.non_primary and not mapperlib.class_mapper( + if self.parent.non_primary and not class_mapper( self.parent.class_, configure=False ).has_property(self.key): raise sa_exc.ArgumentError( @@ -1585,29 +1623,23 @@ class Relationship( ) @property - def cascade(self): + def cascade(self) -> CascadeOptions: """Return the current cascade setting for this :class:`.Relationship`. """ return self._cascade @cascade.setter - def cascade(self, cascade): + def cascade(self, cascade: Union[str, CascadeOptions]): self._set_cascade(cascade) - def _set_cascade(self, cascade): - cascade = CascadeOptions(cascade) + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]): + cascade = CascadeOptions(cascade_arg) if self.viewonly: - non_viewonly = set(cascade).difference( - CascadeOptions._viewonly_cascades + cascade = CascadeOptions( + cascade.intersection(CascadeOptions._viewonly_cascades) ) - if non_viewonly: - raise sa_exc.ArgumentError( - 'Cascade settings "%s" apply to persistence operations ' - "and should not be combined with a viewonly=True " - "relationship." % (", ".join(sorted(non_viewonly))) - ) if "mapper" in self.__dict__: self._check_cascade_settings(cascade) @@ -1754,8 +1786,8 @@ class Relationship( relationship = Relationship( parent, self.secondary, - pj, - sj, + primaryjoin=pj, + secondaryjoin=sj, foreign_keys=foreign_keys, back_populates=self.key, **kwargs, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 5b1d0bb08..74035ec0a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -39,6 +39,7 @@ from . import persistence from . import query from . import state as statelib from ._typing import _O +from ._typing import insp_is_mapper from ._typing import is_composite_class from ._typing import is_user_defined_option from .base import _class_to_mapper @@ -69,12 +70,14 @@ from ..engine.util import TransactionalContext from ..event import dispatcher from ..event import EventTarget from ..inspection import inspect +from ..inspection import Inspectable from ..sql import coercions from ..sql import dml from ..sql import roles from ..sql import Select from ..sql import visitors from ..sql.base import CompileState +from ..sql.schema import Table from ..sql.selectable import ForUpdateArg from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import IdentitySet @@ -90,6 +93,7 @@ if typing.TYPE_CHECKING: from .path_registry import PathRegistry from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.base import Transaction from ..engine.base import TwoPhaseTransaction from ..engine.interfaces import _CoreAnyExecuteParams @@ -103,6 +107,7 @@ if typing.TYPE_CHECKING: from ..sql.base import Executable from ..sql.elements import ClauseElement from ..sql.schema import Table + from ..sql.selectable import TableClause __all__ = [ "Session", @@ -184,7 +189,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -2050,9 +2055,12 @@ class Session(_SessionClassMethods, EventTarget): else: self.__binds[key] = bind else: - if insp.is_selectable: + if TYPE_CHECKING: + assert isinstance(insp, Inspectable) + + if isinstance(insp, Table): self.__binds[insp] = bind - elif insp.is_mapper: + elif insp_is_mapper(insp): self.__binds[insp.class_] = bind for _selectable in insp._all_tables: self.__binds[_selectable] = bind @@ -2211,7 +2219,7 @@ class Session(_SessionClassMethods, EventTarget): # we don't have self.bind and either have self.__binds # or we don't have self.__binds (which is legacy). Look at the # mapper and the clause - if mapper is clause is None: + if mapper is None and clause is None: if self.bind: return self.bind else: @@ -2350,7 +2358,10 @@ class Session(_SessionClassMethods, EventTarget): key = mapper.identity_key_from_primary_key( primary_key_identity, identity_token=identity_token ) - return loading.get_from_identity(self, mapper, key, passive) + + # work around: https://github.com/python/typing/discussions/1143 + return_value = loading.get_from_identity(self, mapper, key, passive) + return return_value @util.non_memoized_property @contextlib.contextmanager diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 85e015193..2d85ba7f6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2162,10 +2162,11 @@ class JoinedLoader(AbstractRelationshipLoader): else: to_adapt = self._gen_pooled_aliased_class(compile_state) - clauses = inspect(to_adapt)._memo( + to_adapt_insp = inspect(to_adapt) + clauses = to_adapt_insp._memo( ("joinedloader_ormadapter", self), orm_util.ORMAdapter, - to_adapt, + to_adapt_insp, equivalents=self.mapper._equivalent_columns, adapt_required=True, allow_label_resolve=False, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4699781a4..3934de535 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -11,8 +11,16 @@ import re import types import typing from typing import Any +from typing import cast +from typing import Dict +from typing import FrozenSet from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Match from typing import Optional +from typing import Sequence from typing import Tuple from typing import Type from typing import TypeVar @@ -20,32 +28,35 @@ from typing import Union import weakref from . import attributes # noqa -from .base import _class_to_mapper # noqa -from .base import _never_set # noqa -from .base import _none_set # noqa -from .base import attribute_str # noqa -from .base import class_mapper # noqa -from .base import InspectionAttr # noqa -from .base import instance_str # noqa -from .base import object_mapper # noqa -from .base import object_state # noqa -from .base import state_attribute_str # noqa -from .base import state_class_str # noqa -from .base import state_str # noqa +from ._typing import _O +from ._typing import insp_is_aliased_class +from ._typing import insp_is_mapper +from ._typing import prop_is_relationship +from .base import _class_to_mapper as _class_to_mapper +from .base import _never_set as _never_set +from .base import _none_set as _none_set +from .base import attribute_str as attribute_str +from .base import class_mapper as class_mapper +from .base import InspectionAttr as InspectionAttr +from .base import instance_str as instance_str +from .base import object_mapper as object_mapper +from .base import object_state as object_state +from .base import state_attribute_str as state_attribute_str +from .base import state_class_str as state_class_str +from .base import state_str as state_str from .interfaces import CriteriaOption -from .interfaces import MapperProperty # noqa +from .interfaces import MapperProperty as MapperProperty from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole from .interfaces import ORMFromClauseRole -from .interfaces import PropComparator # noqa -from .path_registry import PathRegistry # noqa +from .interfaces import PropComparator as PropComparator +from .path_registry import PathRegistry as PathRegistry from .. import event from .. import exc as sa_exc from .. import inspection from .. import sql from .. import util from ..engine.result import result_tuple -from ..sql import base as sql_base from ..sql import coercions from ..sql import expression from ..sql import lambdas @@ -54,19 +65,39 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import ColumnCollection +from ..sql.cache_key import HasCacheKey +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import ColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation from ..util.typing import is_origin_of +from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InternalEntityType + from ._typing import _ORMColumnExprArgument + from .context import _MapperEntity + from .context import ORMCompileState from .mapper import Mapper + from .relationships import Relationship from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import BindParameter + from ..sql.selectable import _ColumnsClauseElement from ..sql.selectable import Alias + from ..sql.selectable import Subquery + from ..sql.visitors import _ET + from ..sql.visitors import anon_map + from ..sql.visitors import ExternallyTraversible _T = TypeVar("_T", bound=Any) @@ -84,7 +115,7 @@ all_cascades = frozenset( ) -class CascadeOptions(frozenset): +class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -104,6 +135,13 @@ class CascadeOptions(frozenset): "delete_orphan", ) + save_update: bool + delete: bool + refresh_expire: bool + merge: bool + expunge: bool + delete_orphan: bool + def __new__(cls, value_list): if isinstance(value_list, str) or value_list is None: return cls.from_string(value_list) @@ -127,7 +165,7 @@ class CascadeOptions(frozenset): values.clear() values.discard("all") - self = frozenset.__new__(CascadeOptions, values) + self = super().__new__(cls, values) # type: ignore self.save_update = "save-update" in values self.delete = "delete" in values self.refresh_expire = "refresh-expire" in values @@ -238,7 +276,7 @@ def polymorphic_union( """ - colnames = util.OrderedSet() + colnames: util.OrderedSet[str] = util.OrderedSet() colnamemaps = {} types = {} for key in table_map: @@ -299,13 +337,13 @@ def polymorphic_union( def identity_key( - class_: Optional[Type[Any]] = None, + class_: Optional[Type[_T]] = None, ident: Union[Any, Tuple[Any, ...]] = None, *, - instance: Optional[Any] = None, - row: Optional[Row] = None, + instance: Optional[_T] = None, + row: Optional[Union[Row, RowMapping]] = None, identity_token: Optional[Any] = None, -) -> _IdentityKeyType: +) -> _IdentityKeyType[_T]: r"""Generate "identity key" tuples, as are used as keys in the :attr:`.Session.identity_map` dictionary. @@ -351,7 +389,7 @@ def identity_key( * ``identity_key(class, row=row, identity_token=token)`` This form is similar to the class/tuple form, except is passed a - database result row as a :class:`.Row` object. + database result row as a :class:`.Row` or :class:`.RowMapping` object. E.g.:: @@ -375,7 +413,7 @@ def identity_key( if ident is None: raise sa_exc.ArgumentError("ident or row is required") return mapper.identity_key_from_primary_key( - util.to_list(ident), identity_token=identity_token + tuple(util.to_list(ident)), identity_token=identity_token ) else: return mapper.identity_key_from_row( @@ -394,24 +432,26 @@ class ORMAdapter(sql_util.ColumnAdapter): """ - is_aliased_class = False - aliased_insp = None + is_aliased_class: bool + aliased_insp: Optional[AliasedInsp[Any]] def __init__( self, - entity, - equivalents=None, - adapt_required=False, - allow_label_resolve=True, - anonymize_labels=False, + entity: _InternalEntityType[Any], + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, ): - info = inspection.inspect(entity) - self.mapper = info.mapper - selectable = info.selectable - if info.is_aliased_class: + self.mapper = entity.mapper + selectable = entity.selectable + if insp_is_aliased_class(entity): self.is_aliased_class = True - self.aliased_insp = info + self.aliased_insp = entity + else: + self.is_aliased_class = False + self.aliased_insp = None sql_util.ColumnAdapter.__init__( self, @@ -428,7 +468,7 @@ class ORMAdapter(sql_util.ColumnAdapter): return not entity or entity.isa(self.mapper) -class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): +class AliasedClass(inspection.Inspectable["AliasedInsp[_O]"], Generic[_O]): r"""Represents an "aliased" form of a mapped class for usage with Query. The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` @@ -489,19 +529,20 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): def __init__( self, - mapped_class_or_ac: Union[Type[_T], "Mapper[_T]", "AliasedClass[_T]"], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - # TODO: None for default here? - with_polymorphic_mappers=(), - with_polymorphic_discriminator=None, - base_alias=None, - use_mapper_path=False, - represents_outer_join=False, + mapped_class_or_ac: _EntityType[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None, + with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None, + base_alias: Optional[AliasedInsp[Any]] = None, + use_mapper_path: bool = False, + represents_outer_join: bool = False, ): - insp = inspection.inspect(mapped_class_or_ac) + insp = cast( + "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac) + ) mapper = insp.mapper nest_adapters = False @@ -519,6 +560,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): elif insp.is_aliased_class: nest_adapters = True + assert alias is not None self._aliased_insp = AliasedInsp( self, insp, @@ -540,7 +582,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): self.__name__ = f"aliased({mapper.class_.__name__})" @classmethod - def _reconstitute_from_aliased_insp(cls, aliased_insp): + def _reconstitute_from_aliased_insp( + cls, aliased_insp: AliasedInsp[_O] + ) -> AliasedClass[_O]: obj = cls.__new__(cls) obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})" obj._aliased_insp = aliased_insp @@ -555,7 +599,7 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return obj - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: _aliased_insp = self.__dict__["_aliased_insp"] except KeyError: @@ -584,7 +628,9 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def _get_from_serialized(self, key, mapped_class, aliased_insp): + def _get_from_serialized( + self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O] + ) -> Any: # this method is only used in terms of the # sqlalchemy.ext.serializer extension attr = getattr(mapped_class, key) @@ -605,23 +651,25 @@ class AliasedClass(inspection.Inspectable["AliasedInsp"], Generic[_T]): return attr - def __repr__(self): + def __repr__(self) -> str: return "<AliasedClass at 0x%x; %s>" % ( id(self), self._aliased_insp._target.__name__, ) - def __str__(self): + def __str__(self) -> str: return str(self._aliased_insp) +@inspection._self_inspects class AliasedInsp( ORMEntityColumnsClauseRole, ORMFromClauseRole, - sql_base.HasCacheKey, + HasCacheKey, InspectionAttr, MemoizedSlots, - Generic[_T], + inspection.Inspectable["AliasedInsp[_O]"], + Generic[_O], ): """Provide an inspection interface for an :class:`.AliasedClass` object. @@ -685,19 +733,36 @@ class AliasedInsp( "_nest_adapters", ) + mapper: Mapper[_O] + selectable: FromClause + _adapter: sql_util.ColumnAdapter + with_polymorphic_mappers: Sequence[Mapper[Any]] + _with_polymorphic_entities: Sequence[AliasedInsp[Any]] + + _weak_entity: weakref.ref[AliasedClass[_O]] + """the AliasedClass that refers to this AliasedInsp""" + + _target: Union[_O, AliasedClass[_O]] + """the thing referred towards by the AliasedClass/AliasedInsp. + + In the vast majority of cases, this is the mapped class. However + it may also be another AliasedClass (alias of alias). + + """ + def __init__( self, - entity: _EntityType, - inspected: _InternalEntityType, - selectable, - name, - with_polymorphic_mappers, - polymorphic_on, - _base_alias, - _use_mapper_path, - adapt_on_names, - represents_outer_join, - nest_adapters, + entity: AliasedClass[_O], + inspected: _InternalEntityType[_O], + selectable: FromClause, + name: Optional[str], + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]], + polymorphic_on: Optional[ColumnElement[Any]], + _base_alias: Optional[AliasedInsp[Any]], + _use_mapper_path: bool, + adapt_on_names: bool, + represents_outer_join: bool, + nest_adapters: bool, ): mapped_class_or_ac = inspected.entity @@ -752,23 +817,22 @@ class AliasedInsp( ) if nest_adapters: + # supports "aliased class of aliased class" use case + assert isinstance(inspected, AliasedInsp) self._adapter = inspected._adapter.wrap(self._adapter) self._adapt_on_names = adapt_on_names self._target = mapped_class_or_ac - # self._target = mapper.class_ # mapped_class_or_ac @classmethod def _alias_factory( cls, - element: Union[ - Type[_T], "Mapper[_T]", "FromClause", "AliasedClass[_T]" - ], - alias=None, - name=None, - flat=False, - adapt_on_names=False, - ) -> Union["AliasedClass[_T]", "Alias"]: + element: Union[_EntityType[_O], FromClause], + alias: Optional[Union[Alias, Subquery]] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + ) -> Union[AliasedClass[_O], FromClause]: if isinstance(element, FromClause): if adapt_on_names: @@ -793,16 +857,16 @@ class AliasedInsp( @classmethod def _with_polymorphic_factory( cls, - base, - classes, - selectable=False, - flat=False, - polymorphic_on=None, - aliased=False, - innerjoin=False, - adapt_on_names=False, - _use_mapper_path=False, - ): + base: Union[_O, Mapper[_O]], + classes: Iterable[Type[Any]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, + ) -> AliasedClass[_O]: primary_mapper = _class_to_mapper(base) @@ -816,7 +880,9 @@ class AliasedInsp( classes, selectable, innerjoin=innerjoin ) if aliased or flat: + assert selectable is not None selectable = selectable._anonymous_fromclause(flat=flat) + return AliasedClass( base, selectable, @@ -828,7 +894,7 @@ class AliasedInsp( ) @property - def entity(self): + def entity(self) -> AliasedClass[_O]: # to eliminate reference cycles, the AliasedClass is held weakly. # this produces some situations where the AliasedClass gets lost, # particularly when one is created internally and only the AliasedInsp @@ -844,7 +910,7 @@ class AliasedInsp( is_aliased_class = True "always returns True" - def _memoized_method___clause_element__(self): + def _memoized_method___clause_element__(self) -> FromClause: return self.selectable._annotate( { "parentmapper": self.mapper, @@ -856,7 +922,7 @@ class AliasedInsp( ) @property - def entity_namespace(self): + def entity_namespace(self) -> AliasedClass[_O]: return self.entity _cache_key_traversal = [ @@ -866,7 +932,7 @@ class AliasedInsp( ] @property - def class_(self): + def class_(self) -> Type[_O]: """Return the mapped class ultimately represented by this :class:`.AliasedInsp`.""" return self.mapper.class_ @@ -878,7 +944,7 @@ class AliasedInsp( else: return PathRegistry.per_mapper(self) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "entity": self.entity, "mapper": self.mapper, @@ -893,8 +959,8 @@ class AliasedInsp( "nest_adapters": self._nest_adapters, } - def __setstate__(self, state): - self.__init__( + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__init__( # type: ignore state["entity"], state["mapper"], state["alias"], @@ -908,7 +974,7 @@ class AliasedInsp( state["nest_adapters"], ) - def _merge_with(self, other): + def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]: # assert self._is_with_polymorphic # assert other._is_with_polymorphic @@ -929,7 +995,6 @@ class AliasedInsp( classes, None, innerjoin=not other.represents_outer_join ) selectable = selectable._anonymous_fromclause(flat=True) - return AliasedClass( primary_mapper, selectable, @@ -937,10 +1002,13 @@ class AliasedInsp( with_polymorphic_discriminator=other.polymorphic_on, use_mapper_path=other._use_mapper_path, represents_outer_join=other.represents_outer_join, - ) + )._aliased_insp - def _adapt_element(self, elem, key=None): - d = { + def _adapt_element( + self, elem: _ORMColumnExprArgument[_T], key: Optional[str] = None + ) -> _ORMColumnExprArgument[_T]: + assert isinstance(elem, ColumnElement) + d: Dict[str, Any] = { "parententity": self, "parentmapper": self.mapper, } @@ -1084,35 +1152,45 @@ class LoaderCriteriaOption(CriteriaOption): ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), ] + root_entity: Optional[Type[Any]] + entity: Optional[_InternalEntityType[Any]] + where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement] + deferred_where_criteria: bool + include_aliases: bool + propagate_to_loaders: bool + def __init__( self, - entity_or_base, - where_criteria, - loader_only=False, - include_aliases=False, - propagate_to_loaders=True, - track_closure_variables=True, + entity_or_base: _EntityType[Any], + where_criteria: _ColumnExpressionArgument[bool], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, ): - entity = inspection.inspect(entity_or_base, False) + entity = cast( + "_InternalEntityType[Any]", + inspection.inspect(entity_or_base, False), + ) if entity is None: - self.root_entity = entity_or_base + self.root_entity = cast("Type[Any]", entity_or_base) self.entity = None else: self.root_entity = None self.entity = entity if callable(where_criteria): + if self.root_entity is not None: + wrap_entity = self.root_entity + else: + assert entity is not None + wrap_entity = entity.entity + self.deferred_where_criteria = True self.where_criteria = lambdas.DeferredLambdaElement( - where_criteria, + where_criteria, # type: ignore roles.WhereHavingRole, - lambda_args=( - _WrapUserEntity( - self.root_entity - if self.root_entity is not None - else self.entity.entity, - ), - ), + lambda_args=(_WrapUserEntity(wrap_entity),), opts=lambdas.LambdaOptions( track_closure_variables=track_closure_variables ), @@ -1126,22 +1204,27 @@ class LoaderCriteriaOption(CriteriaOption): self.include_aliases = include_aliases self.propagate_to_loaders = propagate_to_loaders - def _all_mappers(self): + def _all_mappers(self) -> Iterator[Mapper[Any]]: + if self.entity: - for ent in self.entity.mapper.self_and_descendants: - yield ent + for mp_ent in self.entity.mapper.self_and_descendants: + yield mp_ent else: + assert self.root_entity stack = list(self.root_entity.__subclasses__()) while stack: subclass = stack.pop(0) - ent = inspection.inspect(subclass, raiseerr=False) + ent = cast( + "_InternalEntityType[Any]", + inspection.inspect(subclass, raiseerr=False), + ) if ent: for mp in ent.mapper.self_and_descendants: yield mp else: stack.extend(subclass.__subclasses__()) - def _should_include(self, compile_state): + def _should_include(self, compile_state: ORMCompileState) -> bool: if ( compile_state.select_statement._annotations.get( "for_loader_criteria", None @@ -1151,21 +1234,29 @@ class LoaderCriteriaOption(CriteriaOption): return False return True - def _resolve_where_criteria(self, ext_info): + def _resolve_where_criteria( + self, ext_info: _InternalEntityType[Any] + ) -> ColumnElement[bool]: if self.deferred_where_criteria: - crit = self.where_criteria._resolve_with_args(ext_info.entity) + crit = cast( + "ColumnElement[bool]", + self.where_criteria._resolve_with_args(ext_info.entity), + ) else: - crit = self.where_criteria + crit = self.where_criteria # type: ignore + assert isinstance(crit, ColumnElement) return sql_util._deep_annotate( crit, {"for_loader_criteria": self}, detect_subquery_cols=True ) def process_compile_state_replaced_entities( - self, compile_state, mapper_entities - ): - return self.process_compile_state(compile_state) + self, + compile_state: ORMCompileState, + mapper_entities: Iterable[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) - def process_compile_state(self, compile_state): + def process_compile_state(self, compile_state: ORMCompileState) -> None: """Apply a modification to a given :class:`.CompileState`.""" # if options to limit the criteria to immediate query only, @@ -1173,7 +1264,7 @@ class LoaderCriteriaOption(CriteriaOption): self.get_global_criteria(compile_state.global_attributes) - def get_global_criteria(self, attributes): + def get_global_criteria(self, attributes: Dict[Any, Any]) -> None: for mp in self._all_mappers(): load_criteria = attributes.setdefault( ("additional_entity_criteria", mp), [] @@ -1183,14 +1274,14 @@ class LoaderCriteriaOption(CriteriaOption): inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) -inspection._inspects(AliasedInsp)(lambda target: target) @inspection._self_inspects class Bundle( ORMColumnsClauseRole, SupportsCloneAnnotations, - sql_base.MemoizedHasCacheKey, + MemoizedHasCacheKey, + inspection.Inspectable["Bundle"], InspectionAttr, ): """A grouping of SQL expressions that are returned by a :class:`.Query` @@ -1227,7 +1318,11 @@ class Bundle( _propagate_attrs: _PropagateAttrsType = util.immutabledict() - def __init__(self, name, *exprs, **kw): + exprs: List[_ColumnsClauseElement] + + def __init__( + self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any + ): r"""Construct a new :class:`.Bundle`. e.g.:: @@ -1246,37 +1341,43 @@ class Bundle( """ self.name = self._label = name - self.exprs = exprs = [ + coerced_exprs = [ coercions.expect( roles.ColumnsClauseRole, expr, apply_propagate_attrs=self ) for expr in exprs ] + self.exprs = coerced_exprs self.c = self.columns = ColumnCollection( (getattr(col, "key", col._label), col) - for col in [e._annotations.get("bundle", e) for e in exprs] - ) + for col in [e._annotations.get("bundle", e) for e in coerced_exprs] + ).as_readonly() self.single_entity = kw.pop("single_entity", self.single_entity) - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: return (self.__class__, self.name, self.single_entity) + tuple( [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs] ) @property - def mapper(self): + def mapper(self) -> Mapper[Any]: return self.exprs[0]._annotations.get("parentmapper", None) @property - def entity(self): + def entity(self) -> _InternalEntityType[Any]: return self.exprs[0]._annotations.get("parententity", None) @property - def entity_namespace(self): + def entity_namespace( + self, + ) -> ReadOnlyColumnCollection[str, ColumnElement[Any]]: return self.c - columns = None + columns: ReadOnlyColumnCollection[str, ColumnElement[Any]] + """A namespace of SQL expressions referred to by this :class:`.Bundle`. e.g.:: @@ -1301,7 +1402,7 @@ class Bundle( """ - c = None + c: ReadOnlyColumnCollection[str, ColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" def _clone(self): @@ -1400,32 +1501,30 @@ class _ORMJoin(expression.Join): def __init__( self, - left, - right, - onclause=None, - isouter=False, - full=False, - _left_memo=None, - _right_memo=None, - _extra_criteria=(), + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + _left_memo: Optional[Any] = None, + _right_memo: Optional[Any] = None, + _extra_criteria: Sequence[ColumnElement[bool]] = (), ): - left_info = inspection.inspect(left) + left_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(left), + ) - right_info = inspection.inspect(right) + right_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(right), + ) adapt_to = right_info.selectable # used by joined eager loader self._left_memo = _left_memo self._right_memo = _right_memo - # legacy, for string attr name ON clause. if that's removed - # then the "_joined_from_info" concept can go - left_orm_info = getattr(left, "_joined_from_info", left_info) - self._joined_from_info = right_info - if isinstance(onclause, str): - onclause = getattr(left_orm_info.entity, onclause) - # #### - if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property @@ -1477,20 +1576,23 @@ class _ORMJoin(expression.Join): augment_onclause = onclause is None and _extra_criteria expression.Join.__init__(self, left, right, onclause, isouter, full) + assert self.onclause is not None + if augment_onclause: self.onclause &= sql.and_(*_extra_criteria) if ( not prop and getattr(right_info, "mapper", None) - and right_info.mapper.single + and right_info.mapper.single # type: ignore ): + right_info = cast("_InternalEntityType[Any]", right_info) # if single inheritance target and we are using a manual # or implicit ON clause, augment it the same way we'd augment the # WHERE. single_crit = right_info.mapper._single_table_criterion if single_crit is not None: - if right_info.is_aliased_class: + if insp_is_aliased_class(right_info): single_crit = right_info._adapter.traverse(single_crit) self.onclause = self.onclause & single_crit @@ -1525,19 +1627,27 @@ class _ORMJoin(expression.Join): def join( self, - right, - onclause=None, - isouter=False, - full=False, - join_to_left=None, - ): + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, full=full, isouter=isouter) - def outerjoin(self, right, onclause=None, full=False, join_to_left=None): + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, + ) -> _ORMJoin: return _ORMJoin(self, right, onclause, isouter=True, full=full) -def with_parent(instance, prop, from_entity=None): +def with_parent( + instance: object, + prop: attributes.QueryableAttribute[Any], + from_entity: Optional[_EntityType[Any]] = None, +) -> ColumnElement[bool]: """Create filtering criterion that relates this query's primary entity to the given related instance, using established :func:`_orm.relationship()` @@ -1588,6 +1698,8 @@ def with_parent(instance, prop, from_entity=None): .. versionadded:: 1.2 """ + prop_t: Relationship[Any] + if isinstance(prop, str): raise sa_exc.ArgumentError( "with_parent() accepts class-bound mapped attributes, not strings" @@ -1595,12 +1707,19 @@ def with_parent(instance, prop, from_entity=None): elif isinstance(prop, attributes.QueryableAttribute): if prop._of_type: from_entity = prop._of_type - prop = prop.property + if not prop_is_relationship(prop.property): + raise sa_exc.ArgumentError( + f"Expected relationship property for with_parent(), " + f"got {prop.property}" + ) + prop_t = prop.property + else: + prop_t = prop - return prop._with_parent(instance, from_entity=from_entity) + return prop_t._with_parent(instance, from_entity=from_entity) -def has_identity(object_): +def has_identity(object_: object) -> bool: """Return True if the given object has a database identity. @@ -1616,7 +1735,7 @@ def has_identity(object_): return state.has_identity -def was_deleted(object_): +def was_deleted(object_: object) -> bool: """Return True if the given object was deleted within a session flush. @@ -1633,27 +1752,32 @@ def was_deleted(object_): return state.was_deleted -def _entity_corresponds_to(given, entity): +def _entity_corresponds_to( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of an entity passed to Query that would match the same entity being referred to elsewhere in the query. """ - if entity.is_aliased_class: - if given.is_aliased_class: + if insp_is_aliased_class(entity): + if insp_is_aliased_class(given): if entity._base_alias() is given._base_alias(): return True return False - elif given.is_aliased_class: + elif insp_is_aliased_class(given): if given._use_mapper_path: return entity in given.with_polymorphic_mappers else: return entity is given + assert insp_is_mapper(given) return entity.common_parent(given) -def _entity_corresponds_to_use_path_impl(given, entity): +def _entity_corresponds_to_use_path_impl( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: """determine if 'given' corresponds to 'entity', in terms of a path of loader options where a mapped attribute is taken to be a member of a parent entity. @@ -1673,13 +1797,13 @@ def _entity_corresponds_to_use_path_impl(given, entity): """ - if given.is_aliased_class: + if insp_is_aliased_class(given): return ( - entity.is_aliased_class + insp_is_aliased_class(entity) and not entity._use_mapper_path and (given is entity or entity in given._with_polymorphic_entities) ) - elif not entity.is_aliased_class: + elif not insp_is_aliased_class(entity): return given.isa(entity.mapper) else: return ( @@ -1688,7 +1812,7 @@ def _entity_corresponds_to_use_path_impl(given, entity): ) -def _entity_isa(given, mapper): +def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: """determine if 'given' "is a" mapper, in terms of the given would load rows of type 'mapper'. @@ -1703,42 +1827,6 @@ def _entity_isa(given, mapper): return given.isa(mapper) -def randomize_unitofwork(): - """Use random-ordering sets within the unit of work in order - to detect unit of work sorting issues. - - This is a utility function that can be used to help reproduce - inconsistent unit of work sorting issues. For example, - if two kinds of objects A and B are being inserted, and - B has a foreign key reference to A - the A must be inserted first. - However, if there is no relationship between A and B, the unit of work - won't know to perform this sorting, and an operation may or may not - fail, depending on how the ordering works out. Since Python sets - and dictionaries have non-deterministic ordering, such an issue may - occur on some runs and not on others, and in practice it tends to - have a great dependence on the state of the interpreter. This leads - to so-called "heisenbugs" where changing entirely irrelevant aspects - of the test program still cause the failure behavior to change. - - By calling ``randomize_unitofwork()`` when a script first runs, the - ordering of a key series of sets within the unit of work implementation - are randomized, so that the script can be minimized down to the - fundamental mapping and operation that's failing, while still reproducing - the issue on at least some runs. - - This utility is also available when running the test suite via the - ``--reversetop`` flag. - - """ - from sqlalchemy.orm import unitofwork, session, mapper, dependency - from sqlalchemy.util import topological - from sqlalchemy.testing.util import RandomSet - - topological.set = ( - unitofwork.set - ) = session.set = mapper.set = dependency.set = RandomSet - - def _getitem(iterable_query, item): """calculate __getitem__ in terms of an iterable query object that also has a slice() method. @@ -1780,16 +1868,21 @@ def _getitem(iterable_query, item): return list(iterable_query[item : item + 1])[0] -def _is_mapped_annotation(raw_annotation: Union[type, str], cls: type): +def _is_mapped_annotation( + raw_annotation: Union[type, str], cls: Type[Any] +) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") -def _cleanup_mapped_str_annotation(annotation): +def _cleanup_mapped_str_annotation(annotation: str) -> str: # fix up an annotation that comes in as the form: # 'Mapped[List[Address]]' so that it instead looks like: # 'Mapped[List["Address"]]' , which will allow us to get # "Address" as a string + + inner: Optional[Match[str]] + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if mm and mm.group(1) == "Mapped": stack = [] @@ -1839,8 +1932,8 @@ def _extract_mapped_subtype( else: if ( not hasattr(annotated, "__origin__") - or not issubclass(annotated.__origin__, attr_cls) - and not issubclass(attr_cls, annotated.__origin__) + or not issubclass(annotated.__origin__, attr_cls) # type: ignore + and not issubclass(attr_cls, annotated.__origin__) # type: ignore ): our_annotated_str = ( annotated.__name__ @@ -1853,9 +1946,9 @@ def _extract_mapped_subtype( f'"{attr_cls.__name__}[{our_annotated_str}]".' ) - if len(annotated.__args__) != 1: + if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( "Expected sub-type for Mapped[] annotation" ) - return annotated.__args__[0] + return annotated.__args__[0] # type: ignore |
