diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/_orm_constructors.py | 167 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 142 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 315 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/descriptor_props.py | 48 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/instrumentation.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 97 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 63 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 45 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/typing.py | 8 |
13 files changed, 799 insertions, 160 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index b7d1df532..4f19ba946 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta from .decl_api import DeclarativeMeta as DeclarativeMeta from .decl_api import declared_attr as declared_attr from .decl_api import has_inherited_table as has_inherited_table +from .decl_api import MappedAsDataclass as MappedAsDataclass from .decl_api import registry as registry from .decl_api import synonym_for as synonym_for from .descriptor_props import Composite as Composite diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 0692cac09..ece6a52be 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -21,9 +21,9 @@ from typing import Union from . import mapperlib as mapperlib from ._typing import _O -from .base import Mapped from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .properties import ColumnProperty from .properties import MappedColumn from .query import AliasOption @@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption from .. import sql from .. import util from ..exc import InvalidRequestError +from ..sql._typing import _no_kw +from ..sql.base import _NoArg from ..sql.base import SchemaEventTarget from ..sql.schema import SchemaConst from ..sql.selectable import FromClause @@ -105,6 +107,10 @@ def mapped_column( Union[_TypeEngineArgument[Any], SchemaEventTarget] ] = None, *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -113,7 +119,6 @@ def mapped_column( name: Optional[str] = None, type_: Optional[_TypeEngineArgument[Any]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", - default: Optional[Any] = None, doc: Optional[str] = None, key: Optional[str] = None, index: Optional[bool] = None, @@ -300,6 +305,12 @@ def mapped_column( type_=type_, autoincrement=autoincrement, default=default, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, key=key, index=index, @@ -325,6 +336,10 @@ def column_property( deferred: bool = False, raiseload: bool = False, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, @@ -416,6 +431,12 @@ def column_property( return ColumnProperty( column, *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), group=group, deferred=deferred, raiseload=raiseload, @@ -429,25 +450,61 @@ def column_property( @overload def composite( - class_: Type[_CC], + _class_or_attr: Type[_CC], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[_CC]: ... @overload def composite( + _class_or_attr: _CompositeAttrType[Any], *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: ... def composite( - class_: Any = None, + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, *attrs: _CompositeAttrType[Any], - **kwargs: Any, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -497,7 +554,26 @@ def composite( :attr:`.MapperProperty.info` attribute of this object. """ - return Composite(class_, *attrs, **kwargs) + if __kw: + raise _no_kw() + + return Composite( + _class_or_attr, + *attrs, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + info=info, + doc=doc, + ) def with_loader_criteria( @@ -700,6 +776,10 @@ def relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -1532,6 +1612,12 @@ def relationship( post_update=post_update, cascade=cascade, viewonly=viewonly, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), lazy=lazy, passive_deletes=passive_deletes, passive_updates=passive_updates, @@ -1559,6 +1645,10 @@ def synonym( map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, info: Optional[_InfoType] = None, doc: Optional[str] = None, ) -> Synonym[Any]: @@ -1670,6 +1760,12 @@ def synonym( map_column=map_column, descriptor=descriptor, comparator_factory=comparator_factory, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), doc=doc, info=info, ) @@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument: def deferred( column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], - **kw: Any, + group: Optional[str] = None, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, ) -> ColumnProperty[_T]: r"""Indicate a column-based mapped attribute that by default will not load unless accessed. @@ -1803,21 +1909,41 @@ def deferred( :ref:`deferred_raiseload` - :param \**kw: additional keyword arguments passed to - :class:`.ColumnProperty`. + Additional arguments are the same as that of :func:`_orm.column_property`. .. seealso:: :ref:`deferred` """ - kw["deferred"] = True - return ColumnProperty(column, *additional_columns, **kw) + return ColumnProperty( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, + repr, + default, + default_factory, + ), + group=group, + deferred=True, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) def query_expression( default_expr: _ORMColumnExprArgument[_T] = sql.null(), -) -> Mapped[_T]: + *, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> ColumnProperty[_T]: """Indicate an attribute that populates from a query-time SQL expression. :param default_expr: Optional SQL expression object that will be used in @@ -1840,7 +1966,18 @@ def query_expression( :ref:`mapper_querytime_expression` """ - prop = ColumnProperty(default_expr) + prop = ColumnProperty( + default_expr, + attribute_options=_AttributeOptions( + _NoArg.NO_ARG, + repr, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ), + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) prop.strategy_key = (("query_expression", True),) return prop diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 1c343b04c..553a50107 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -33,6 +33,13 @@ from . import clsregistry from . import instrumentation from . import interfaces from . import mapperlib +from ._orm_constructors import column_property +from ._orm_constructors import composite +from ._orm_constructors import deferred +from ._orm_constructors import mapped_column +from ._orm_constructors import query_expression +from ._orm_constructors import relationship +from ._orm_constructors import synonym from .attributes import InstrumentedAttribute from .base import _inspect_mapped_class from .base import Mapped @@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute from .decl_base import _mapper +from .descriptor_props import Composite +from .descriptor_props import Synonym from .descriptor_props import Synonym as _orm_synonym from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .relationships import Relationship from .state import InstanceState from .. import exc from .. import inspection @@ -60,9 +72,9 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _O from ._typing import _RegistryType - from .descriptor_props import Synonym from .instrumentation import ClassManager from .interfaces import MapperProperty + from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept( """ +@compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), +) +class DCTransformDeclarative(DeclarativeAttributeIntercept): + """metaclass that includes @dataclass_transforms""" + + class DeclarativeMeta( _DynamicAttributesType, inspection.Inspectable[Mapper[Any]] ): @@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): cls._sa_registry.map_declaratively(cls) +class MappedAsDataclass(metaclass=DCTransformDeclarative): + """Mixin class to indicate when mapping this class, also convert it to be + a dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + + .. versionadded:: 2.0 + """ + + def __init_subclass__( + cls, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> None: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + super().__init_subclass__() + + class DeclarativeBase( inspection.Inspectable[InstanceState[Any]], metaclass=DeclarativeAttributeIntercept, ): """Base class used for declarative class definitions. + The :class:`_orm.DeclarativeBase` allows for the creation of new declarative bases in such a way that is compatible with type checkers:: @@ -1121,7 +1183,7 @@ class registry: bases = not isinstance(cls, tuple) and (cls,) or cls - class_dict = dict(registry=self, metadata=metadata) + class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata) if isinstance(cls, type): class_dict["__doc__"] = cls.__doc__ @@ -1142,6 +1204,78 @@ class registry: return metaclass(name, bases, class_dict) + @compat_typing.dataclass_transform( + field_descriptors=( + MappedColumn[Any], + Relationship[Any], + Composite[Any], + ColumnProperty[Any], + Synonym[Any], + mapped_column, + relationship, + composite, + column_property, + synonym, + deferred, + query_expression, + ), + ) + @overload + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: + ... + + @overload + def mapped_as_dataclass( + self, + __cls: Literal[None] = ..., + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Callable[[Type[_O]], Type[_O]]: + ... + + def mapped_as_dataclass( + self, + __cls: Optional[Type[_O]] = None, + *, + init: bool = True, + repr: bool = True, # noqa: A002 + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: + """Class decorator that will apply the Declarative mapping process + to a given class, and additionally convert the class to be a + Python dataclass. + + .. seealso:: + + :meth:`_orm.registry.mapped` + + .. versionadded:: 2.0 + + + """ + + def decorate(cls: Type[_O]) -> Type[_O]: + cls._sa_apply_dc_transforms = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + } + _as_declarative(self, cls, cls.__dict__) + return cls + + if __cls: + return decorate(__cls) + else: + return decorate + def mapped(self, cls: Type[_O]) -> Type[_O]: """Class decorator that will apply the Declarative mapping process to a given class. @@ -1174,6 +1308,10 @@ class registry: that will apply Declarative mapping to subclasses automatically using a Python metaclass. + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + """ _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a66421e22..54a272f86 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -10,6 +10,8 @@ from __future__ import annotations import collections +import dataclasses +import re from typing import Any from typing import Callable from typing import cast @@ -40,6 +42,7 @@ from .base import _is_mapped_class from .base import InspectionAttr from .descriptor_props import Composite from .descriptor_props import Synonym +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MappedAttribute from .interfaces import _MapsColumns @@ -48,15 +51,18 @@ from .mapper import Mapper as mapper from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn +from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper from .. import event from .. import exc from .. import util from ..sql import expression +from ..sql.base import _NoArg from ..sql.schema import Column from ..sql.schema import Table from ..util import topological +from ..util.typing import _AnnotationScanType from ..util.typing import Protocol if TYPE_CHECKING: @@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig): "mapper_args", "mapper_args_fn", "inherits", + "allow_dataclass_fields", + "dataclass_setup_arguments", ) registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, bool]] + collected_annotations: Dict[str, Tuple[Any, Any, bool]] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] + dataclass_setup_arguments: Optional[Dict[str, Any]] + """if the class has SQLAlchemy native dataclass parameters, where + we will create a SQLAlchemy dataclass (not a real dataclass). + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field""" + def __init__( self, registry: _RegistryType, @@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig): self.declared_columns = util.OrderedSet() self.column_copies = {} + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + cld = dataclasses.is_dataclass(cls_) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. + # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + self._setup_declared_events() self._scan_attributes() + self._setup_dataclasses_transforms() + with mapperlib._CONFIGURE_MUTEX: clsregistry.add_class( self.classname, self.cls, registry._class_registry @@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig): attribute, taking SQLAlchemy-enabled dataclass fields into account. """ - sa_dataclass_metadata_key = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - if sa_dataclass_metadata_key is None: + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def attribute_is_overridden(key: str, obj: Any) -> bool: return getattr(cls, key) is not obj @@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig): "__dict__", "__weakref__", "_sa_class_manager", + "_sa_apply_dc_transforms", "__dict__", "__weakref__", ] @@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig): adjusting for SQLAlchemy fields embedded in dataclass fields. """ - sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( - cls, "__sa_dataclass_metadata_key__" - ) - cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) @@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig): names = util.merge_lists_w_ordering( [n for n in cls_vars if n not in skip], list(cls_annotations) ) - if sa_dataclass_metadata_key is None: + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: def local_attributes_for_class() -> Iterable[ Tuple[str, Any, Any, bool] @@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig): name, obj, annotation, - is_dataclass, + is_dataclass_field, ) in local_attributes_for_class(): - if name == "__mapper_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not mapper_args_fn and (not class_mapped or check_decl): - # don't even invoke __mapper_args__ until - # after we've determined everything about the - # mapped table. - # make a copy of it so a class-level dictionary - # is not overwritten when we update column-based - # arguments. - def _mapper_args_fn() -> Dict[str, Any]: - return dict(cls_as_Decl.__mapper_args__) - - mapper_args_fn = _mapper_args_fn - - elif name == "__tablename__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not tablename and (not class_mapped or check_decl): - tablename = cls_as_Decl.__tablename__ - elif name == "__table_args__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - if not table_args and (not class_mapped or check_decl): - table_args = cls_as_Decl.__table_args__ - if not isinstance( - table_args, (tuple, dict, type(None)) + if re.match(r"^__.+__$", name): + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl ): - raise exc.ArgumentError( - "__table_args__ value must be a tuple, " - "dict, or None" - ) - if base is not cls: - inherited_table_args = True + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names + continue elif class_mapped: if _is_declarative_props(obj): util.warn( @@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig): # acting like that for now. if isinstance(obj, (Column, MappedColumn)): - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) # already copied columns to the mapped class. continue @@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig): ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: - if is_dataclass: + if is_dataclass_field: # access attribute using normal class access # first, to see if it's been mapped on a # superclass. note if the dataclasses.field() @@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret.doc = obj.__doc__ - self.collected_annotations[name] = ( + self._collect_annotation( + name, obj._collect_return_annotation(), False, + True, + obj, ) elif _is_mapped_annotation(annotation, cls): - self.collected_annotations[name] = ( - annotation, - is_dataclass, + self._collect_annotation( + name, annotation, is_dataclass_field, True, obj ) if obj is None: if not fixed_table: @@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig): # declarative mapping. however, check for some # more common mistakes self._warn_for_decl_attributes(base, name, obj) - elif is_dataclass and ( + elif is_dataclass_field and ( name not in clsdict_view or clsdict_view[name] is not obj ): # here, we are definitely looking at the target class @@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig): obj = obj.fget() collected_attributes[name] = obj - self.collected_annotations[name] = ( - annotation, - True, + self._collect_annotation( + name, annotation, True, False, obj ) else: - self.collected_annotations[name] = ( - annotation, - False, + self._collect_annotation( + name, annotation, False, None, obj ) if ( obj is None @@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig): collected_attributes[name] = MappedColumn() elif name in clsdict_view: collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. + # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup if inherited_table_args and not tablename: table_args = None @@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig): self.tablename = tablename self.mapper_args_fn = mapper_args_fn + def _setup_dataclasses_transforms(self) -> None: + + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno in ( + (key, mapped_anno if mapped_anno else raw_anno) + for key, ( + raw_anno, + mapped_anno, + is_dc, + ) in self.collected_annotations.items() + ) + ] + + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item # type: ignore + elif len(item) == 3: + name, tp, spec = item # type: ignore + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + self.cls.__annotations__ = annotations + + dataclasses.dataclass(self.cls, **dataclass_setup_arguments) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + is_dataclass: bool, + expect_mapped: Optional[bool], + attr_value: Any, + ) -> None: + + if expect_mapped is None: + expect_mapped = isinstance(attr_value, _MappedAttribute) + + extracted_mapped_annotation = _extract_mapped_subtype( + raw_annotation, + self.cls, + name, + type(attr_value), + required=False, + is_dataclass_field=False, + expect_mapped=expect_mapped and not self.allow_dataclass_fields, + ) + + self.collected_annotations[name] = ( + raw_annotation, + extracted_mapped_annotation, + is_dataclass, + ) + def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any ) -> None: @@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig): _undefer_column_name( k, self.column_copies.get(value, value) # type: ignore ) - elif isinstance(value, _IntrospectsAnnotations): - annotation, is_dataclass = self.collected_annotations.get( - k, (None, False) - ) - value.declarative_scan( - self.registry, cls, k, annotation, is_dataclass - ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + extracted_mapped_annotation, + is_dataclass, + ) = self.collected_annotations.get(k, (None, None, False)) + value.declarative_scan( + self.registry, + cls, + k, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + + if ( + isinstance(value, (MapperProperty, _MapsColumns)) + and value._has_dataclass_arguments + and not self.dataclass_setup_arguments + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes dataclasses " + f"argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + our_stuff[k] = value def _extract_declared_columns(self) -> None: @@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig): # extract columns from the class dict declared_columns = self.declared_columns name_to_prop_key = collections.defaultdict(set) + for key, c in list(our_stuff.items()): if isinstance(c, _MapsColumns): @@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig): # otherwise, Mapper will map it under the column key. if mp_to_assign is None and key != col.key: our_stuff[key] = col - elif isinstance(c, Column): # undefer previously occurred here, and now occurs earlier. # ensure every column we get here has been named diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 8c89f96aa..a366a9534 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -35,11 +35,11 @@ from .base import LoaderCallableStatus from .base import Mapped from .base import PassiveFlag from .base import SQLORMOperations +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator -from .util import _extract_mapped_subtype from .util import _none_set from .. import event from .. import exc as sa_exc @@ -200,24 +200,26 @@ class Composite( def __init__( self, - class_: Union[ + _class_or_attr: Union[ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] ] = None, *attrs: _CompositeAttrType[Any], + attribute_options: Optional[_AttributeOptions] = None, active_history: bool = False, deferred: bool = False, group: Optional[str] = None, comparator_factory: Optional[Type[Comparator[_CC]]] = None, info: Optional[_InfoType] = None, + **kwargs: Any, ): - super().__init__() + super().__init__(attribute_options=attribute_options) - if isinstance(class_, (Mapped, str, sql.ColumnElement)): - self.attrs = (class_,) + attrs + if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)): + self.attrs = (_class_or_attr,) + attrs # will initialize within declarative_scan self.composite_class = None # type: ignore else: - self.composite_class = class_ # type: ignore + self.composite_class = _class_or_attr # type: ignore self.attrs = attrs self.active_history = active_history @@ -332,19 +334,15 @@ class Composite( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - MappedColumn = util.preloaded.orm_properties.MappedColumn - - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - self.composite_class is None, - is_dataclass_field, - ) - + if ( + self.composite_class is None + and extracted_mapped_annotation is None + ): + self._raise_for_required(key, cls) + argument = extracted_mapped_annotation if argument and self.composite_class is None: if isinstance(argument, str) or hasattr( argument, "__forward_arg__" @@ -371,11 +369,18 @@ class Composite( for param, attr in itertools.zip_longest( insp.parameters.values(), self.attrs ): - if param is None or attr is None: + if param is None: raise sa_exc.ArgumentError( - f"number of arguments to {self.composite_class.__name__} " - f"class and number of attributes don't match" + f"number of composite attributes " + f"{len(self.attrs)} exceeds " + f"that of the number of attributes in class " + f"{self.composite_class.__name__} {len(insp.parameters)}" ) + if attr is None: + # fill in missing attr spots with empty MappedColumn + attr = MappedColumn() + self.attrs += (attr,) + if isinstance(attr, MappedColumn): attr.declarative_scan_for_composite( registry, cls, key, param.name, param.annotation @@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]): map_column: Optional[bool] = None, descriptor: Optional[Any] = None, comparator_factory: Optional[Type[PropComparator[_T]]] = None, + attribute_options: Optional[_AttributeOptions] = None, info: Optional[_InfoType] = None, doc: Optional[str] = None, ): - super().__init__() + super().__init__(attribute_options=attribute_options) self.name = name self.map_column = map_column diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 4fa61b7ce..33de2aee9 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -113,6 +113,7 @@ class ClassManager( "previously known as deferred_scalar_loader" init_method: Optional[Callable[..., None]] + original_init: Optional[Callable[..., None]] = None factory: Optional[_ManagerFactory] @@ -229,7 +230,7 @@ class ClassManager( if finalize and not self._finalized: self._finalize() - def _finalize(self): + def _finalize(self) -> None: if self._finalized: return self._finalized = True @@ -238,14 +239,14 @@ class ClassManager( _instrumentation_factory.dispatch.class_instrument(self.class_) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return id(self) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return other is self @property - def is_mapped(self): + def is_mapped(self) -> bool: return "mapper" in self.__dict__ @HasMemoized.memoized_attribute diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b5569ce06..e0034061d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -19,6 +19,7 @@ are exposed when inspecting mappings. from __future__ import annotations import collections +import dataclasses import typing from typing import Any from typing import Callable @@ -27,6 +28,8 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import NamedTuple +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Set @@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401 from .base import RelationshipDirection as RelationshipDirection # noqa: F401 from .base import SQLORMOperations from .. import ColumnElement +from .. import exc as sa_exc from .. import inspection from .. import util from ..sql import operators from ..sql import roles from ..sql import visitors +from ..sql.base import _NoArg from ..sql.base import ExecutableOption from ..sql.cache_key import HasCacheKey from ..sql.schema import Column @@ -141,6 +146,7 @@ class _IntrospectsAnnotations: cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: """Perform class-specific initializaton at early declarative scanning @@ -150,6 +156,70 @@ class _IntrospectsAnnotations: """ + def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{self.__class__.__name__}" construct are None or not present' + ) + + +class _AttributeOptions(NamedTuple): + """define Python-local attribute behavior options common to all + :class:`.MapperProperty` objects. + + Currently this includes dataclass-generation arguments. + + .. versionadded:: 2.0 + + """ + + dataclasses_init: Union[_NoArg, bool] + dataclasses_repr: Union[_NoArg, bool] + dataclasses_default: Union[_NoArg, Any] + dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + + def _as_dataclass_field(self) -> Any: + """Return a ``dataclasses.Field`` object given these arguments.""" + + kw: Dict[str, Any] = {} + if self.dataclasses_default_factory is not _NoArg.NO_ARG: + kw["default_factory"] = self.dataclasses_default_factory + if self.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = self.dataclasses_default + if self.dataclasses_init is not _NoArg.NO_ARG: + kw["init"] = self.dataclasses_init + if self.dataclasses_repr is not _NoArg.NO_ARG: + kw["repr"] = self.dataclasses_repr + + return dataclasses.field(**kw) + + @classmethod + def _get_arguments_for_make_dataclass( + cls, key: str, annotation: Type[Any], elem: _T + ) -> Union[ + Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]] + ]: + """given attribute key, annotation, and value from a class, return + the argument tuple we would pass to dataclasses.make_dataclass() + for this attribute. + + """ + if isinstance(elem, (MapperProperty, _MapsColumns)): + dc_field = elem._attribute_options._as_dataclass_field() + + return (key, annotation, dc_field) + elif elem is not _NoArg.NO_ARG: + # why is typing not erroring on this? + return (key, annotation, elem) + else: + return (key, annotation) + + +_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( + _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG +) + class _MapsColumns(_MappedAttribute[_T]): """interface for declarative-capable construct that delivers one or more @@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]): __slots__ = () + _attribute_options: _AttributeOptions + _has_dataclass_arguments: bool + @property def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: """return a MapperProperty to be assigned to the declarative mapping""" @@ -199,6 +272,8 @@ class MapperProperty( __slots__ = ( "_configure_started", "_configure_finished", + "_attribute_options", + "_has_dataclass_arguments", "parent", "key", "info", @@ -241,6 +316,15 @@ class MapperProperty( doc: Optional[str] """optional documentation string""" + _attribute_options: _AttributeOptions + """behavioral options for ORM-enabled Python attributes + + .. versionadded:: 2.0 + + """ + + _has_dataclass_arguments: bool + def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -349,9 +433,20 @@ class MapperProperty( """ - def __init__(self) -> None: + def __init__( + self, attribute_options: Optional[_AttributeOptions] = None + ) -> None: self._configure_started = False self._configure_finished = False + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS def init(self) -> None: """Called after all mappers are created to assemble diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ad3e9f248..7655f3ae2 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -30,13 +30,14 @@ from . import strategy_options from .descriptor_props import Composite from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import Synonym +from .interfaces import _AttributeOptions +from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS from .interfaces import _IntrospectsAnnotations from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import Relationship -from .util import _extract_mapped_subtype from .util import _orm_full_deannotate from .. import exc as sa_exc from .. import ForeignKey @@ -45,6 +46,7 @@ from .. import util from ..sql import coercions from ..sql import roles from ..sql import sqltypes +from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import Column from ..sql.schema import SchemaConst @@ -131,6 +133,7 @@ class ColumnProperty( self, column: _ORMColumnExprArgument[_T], *additional_columns: _ORMColumnExprArgument[Any], + attribute_options: Optional[_AttributeOptions] = None, group: Optional[str] = None, deferred: bool = False, raiseload: bool = False, @@ -141,7 +144,9 @@ class ColumnProperty( doc: Optional[str] = None, _instrument: bool = True, ): - super(ColumnProperty, self).__init__() + super(ColumnProperty, self).__init__( + attribute_options=attribute_options + ) columns = (column,) + additional_columns self._orig_columns = [ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns @@ -193,6 +198,7 @@ class ColumnProperty( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.columns[0] @@ -487,13 +493,38 @@ class MappedColumn( "foreign_keys", "_has_nullable", "deferred", + "_attribute_options", + "_has_dataclass_arguments", ) deferred: bool column: Column[_T] foreign_keys: Optional[Set[ForeignKey]] + _attribute_options: _AttributeOptions def __init__(self, *arg: Any, **kw: Any): + self._attribute_options = attr_opts = kw.pop( + "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS + ) + + self._has_dataclass_arguments = False + + if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: + if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG: + self._has_dataclass_arguments = True + kw["default"] = attr_opts.dataclasses_default_factory + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default + + if ( + attr_opts.dataclasses_init is not _NoArg.NO_ARG + or attr_opts.dataclasses_repr is not _NoArg.NO_ARG + ): + self._has_dataclass_arguments = True + + if "default" in kw and kw["default"] is _NoArg.NO_ARG: + kw.pop("default") + self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) self.foreign_keys = self.column.foreign_keys @@ -509,13 +540,19 @@ class MappedColumn( new.deferred = self.deferred new.foreign_keys = new.column.foreign_keys new._has_nullable = self._has_nullable + new._attribute_options = self._attribute_options + new._has_dataclass_arguments = self._has_dataclass_arguments util.set_creation_order(new) return new @property def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: if self.deferred: - return ColumnProperty(self.column, deferred=True) + return ColumnProperty( + self.column, + deferred=True, + attribute_options=self._attribute_options, + ) else: return None @@ -543,6 +580,7 @@ class MappedColumn( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: column = self.column @@ -553,18 +591,15 @@ class MappedColumn( sqltype = column.type - argument = _extract_mapped_subtype( - annotation, - cls, - key, - MappedColumn, - sqltype._isnull and not self.column.foreign_keys, - is_dataclass_field, - ) - if argument is None: - return + if extracted_mapped_annotation is None: + if sqltype._isnull and not self.column.foreign_keys: + self._raise_for_required(key, cls) + else: + return - self._init_column_for_annotation(cls, registry, argument) + self._init_column_for_annotation( + cls, registry, extracted_mapped_annotation + ) @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 1186f0f54..deaf52147 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -49,6 +49,7 @@ from .base import class_mapper from .base import LoaderCallableStatus from .base import PassiveFlag from .base import state_str +from .interfaces import _AttributeOptions from .interfaces import _IntrospectsAnnotations from .interfaces import MANYTOMANY from .interfaces import MANYTOONE @@ -56,7 +57,6 @@ from .interfaces import ONETOMANY from .interfaces import PropComparator from .interfaces import RelationshipDirection from .interfaces import StrategizedProperty -from .util import _extract_mapped_subtype from .util import _orm_annotate from .util import _orm_deannotate from .util import CascadeOptions @@ -355,6 +355,7 @@ class Relationship( post_update: bool = False, cascade: str = "save-update, merge", viewonly: bool = False, + attribute_options: Optional[_AttributeOptions] = None, lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -380,7 +381,7 @@ class Relationship( _local_remote_pairs: Optional[_ColumnPairs] = None, _legacy_inactive_history_style: bool = False, ): - super(Relationship, self).__init__() + super(Relationship, self).__init__(attribute_options=attribute_options) self.uselist = uselist self.argument = argument @@ -1701,18 +1702,19 @@ class Relationship( cls: Type[Any], key: str, annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: - argument = _extract_mapped_subtype( - annotation, - cls, - key, - Relationship, - self.argument is None, - is_dataclass_field, - ) - if argument is None: - return + argument = extracted_mapped_annotation + + if extracted_mapped_annotation is None: + + if self.argument is None: + self._raise_for_required(key, cls) + else: + return + + argument = extracted_mapped_annotation if hasattr(argument, "__origin__"): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c50cc5bac..520c95672 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( - raw_annotation: Union[type, str], cls: Type[Any] + raw_annotation: _AnnotationScanType, cls: Type[Any] ) -> bool: annotated = de_stringify_annotation(cls, raw_annotation) return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm") @@ -1969,9 +1969,14 @@ def _extract_mapped_subtype( attr_cls: Type[Any], required: bool, is_dataclass_field: bool, - superclasses: Optional[Tuple[Type[Any], ...]] = None, + expect_mapped: bool = True, ) -> Optional[Union[type, str]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + Includes error raise scenarios and other options. + + """ if raw_annotation is None: if required: @@ -1989,25 +1994,29 @@ def _extract_mapped_subtype( if is_dataclass_field: return annotated else: - # TODO: there don't seem to be tests for the failure - # conditions here - if not hasattr(annotated, "__origin__") or ( - not issubclass( - annotated.__origin__, # type: ignore - superclasses if superclasses else attr_cls, - ) - and not issubclass(attr_cls, annotated.__origin__) # type: ignore + if not hasattr(annotated, "__origin__") or not is_origin_of( + annotated, "Mapped", module="sqlalchemy.orm" ): - our_annotated_str = ( - annotated.__name__ + anno_name = ( + getattr(annotated, "__name__", None) if not isinstance(annotated, str) - else repr(annotated) - ) - raise sa_exc.ArgumentError( - f'Type annotation for "{cls.__name__}.{key}" should use the ' - f'syntax "Mapped[{our_annotated_str}]" or ' - f'"{attr_cls.__name__}[{our_annotated_str}]".' + else None ) + if anno_name is None: + our_annotated_str = repr(annotated) + else: + our_annotated_str = anno_name + + if expect_mapped: + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]".' + ) + + else: + return annotated if len(annotated.__args__) != 1: # type: ignore raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 53f76f3ce..d4e4d2dca 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -25,6 +25,7 @@ from .. import event from .. import util from ..orm import declarative_base from ..orm import DeclarativeBase +from ..orm import MappedAsDataclass from ..orm import registry from ..schema import sort_tables_and_constraints @@ -90,7 +91,14 @@ class TestBase: @config.fixture() def registry(self, metadata): - reg = registry(metadata=metadata) + reg = registry( + metadata=metadata, + type_annotation_map={ + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + }, + ) yield reg reg.dispose() @@ -109,6 +117,21 @@ class TestBase: yield Base Base.registry.dispose() + @config.fixture + def dc_decl_base(self, metadata): + _md = metadata + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: sa.String().with_variant( + sa.String(50), "mysql", "mariadb" + ) + } + + yield Base + Base.registry.dispose() + @config.fixture() def future_connection(self, future_engine, connection): # integrate the future_engine and connection fixtures so diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index adbbf143f..4ce1e7ff3 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -230,7 +230,11 @@ def inspect_formatargspec( def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated - with a class.""" + with a class as an already processed dataclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): return dataclasses.fields(cls) @@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]: """Return a sequence of all dataclasses.Field objects associated with - a class, excluding those that originate from a superclass.""" + an already processed dataclass, excluding those that originate from a + superclass. + + The class must **already be a dataclass** for Field objects to be returned. + + """ if dataclasses.is_dataclass(cls): super_fields: Set[dataclasses.Field[Any]] = set() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 44e26f609..454de100b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401 from . import compat + +# more zimports issues +if True: + from typing_extensions import ( # noqa: F401 + dataclass_transform as dataclass_transform, + ) + + _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT") _KT_co = TypeVar("_KT_co", covariant=True) |
