summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-02-18 10:05:12 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-05-20 14:19:02 -0400
commita463b1109abb60fc85f8356f30c0351a4e2ed71e (patch)
treede8f96b7bce319fc0f19f56b302202ea3e4e91db
parent9e7bed9df601ead02fd96bf2fc787b23b536d2d6 (diff)
downloadsqlalchemy-a463b1109abb60fc85f8356f30c0351a4e2ed71e.tar.gz
implement dataclass_transforms
Implement a new means of creating a mapped dataclass where instead of applying the `@dataclass` decorator distinctly, the declarative process itself can create the dataclass. MapperProperty and MappedColumn objects themselves take the place of the dataclasses.Field object when constructing the class. The overall approach is made possible at the typing level using pep-681 dataclass transforms [1]. This new approach should be able to completely supersede the previous "dataclasses" approach of embedding metadata into Field() objects, which remains a mutually exclusive declarative setup style (mixing them introduces new issues that are not worth solving). [1] https://peps.python.org/pep-0681/#transform-descriptor-types-example Fixes: #7642 Change-Id: I6ba88a87c5df38270317b4faf085904d91c8a63c
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py167
-rw-r--r--lib/sqlalchemy/orm/decl_api.py142
-rw-r--r--lib/sqlalchemy/orm/decl_base.py315
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py48
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py9
-rw-r--r--lib/sqlalchemy/orm/interfaces.py97
-rw-r--r--lib/sqlalchemy/orm/properties.py63
-rw-r--r--lib/sqlalchemy/orm/relationships.py26
-rw-r--r--lib/sqlalchemy/orm/util.py45
-rw-r--r--lib/sqlalchemy/testing/fixtures.py25
-rw-r--r--lib/sqlalchemy/util/compat.py13
-rw-r--r--lib/sqlalchemy/util/typing.py8
-rw-r--r--pyproject.toml1
-rw-r--r--setup.cfg2
-rw-r--r--test/orm/declarative/test_dc_transforms.py816
-rw-r--r--test/orm/declarative/test_typed_mapping.py46
17 files changed, 1661 insertions, 163 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index b7d1df532..4f19ba946 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta
from .decl_api import DeclarativeMeta as DeclarativeMeta
from .decl_api import declared_attr as declared_attr
from .decl_api import has_inherited_table as has_inherited_table
+from .decl_api import MappedAsDataclass as MappedAsDataclass
from .decl_api import registry as registry
from .decl_api import synonym_for as synonym_for
from .descriptor_props import Composite as Composite
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 0692cac09..ece6a52be 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -21,9 +21,9 @@ from typing import Union
from . import mapperlib as mapperlib
from ._typing import _O
-from .base import Mapped
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .properties import ColumnProperty
from .properties import MappedColumn
from .query import AliasOption
@@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption
from .. import sql
from .. import util
from ..exc import InvalidRequestError
+from ..sql._typing import _no_kw
+from ..sql.base import _NoArg
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
@@ -105,6 +107,10 @@ def mapped_column(
Union[_TypeEngineArgument[Any], SchemaEventTarget]
] = None,
*args: SchemaEventTarget,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
] = SchemaConst.NULL_UNSPECIFIED,
@@ -113,7 +119,6 @@ def mapped_column(
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[Any]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
doc: Optional[str] = None,
key: Optional[str] = None,
index: Optional[bool] = None,
@@ -300,6 +305,12 @@ def mapped_column(
type_=type_,
autoincrement=autoincrement,
default=default,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
key=key,
index=index,
@@ -325,6 +336,10 @@ def column_property(
deferred: bool = False,
raiseload: bool = False,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
active_history: bool = False,
expire_on_flush: bool = True,
info: Optional[_InfoType] = None,
@@ -416,6 +431,12 @@ def column_property(
return ColumnProperty(
column,
*additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
group=group,
deferred=deferred,
raiseload=raiseload,
@@ -429,25 +450,61 @@ def column_property(
@overload
def composite(
- class_: Type[_CC],
+ _class_or_attr: Type[_CC],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[_CC]:
...
@overload
def composite(
+ _class_or_attr: _CompositeAttrType[Any],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
...
def composite(
- class_: Any = None,
+ _class_or_attr: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
@@ -497,7 +554,26 @@ def composite(
:attr:`.MapperProperty.info` attribute of this object.
"""
- return Composite(class_, *attrs, **kwargs)
+ if __kw:
+ raise _no_kw()
+
+ return Composite(
+ _class_or_attr,
+ *attrs,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=deferred,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ info=info,
+ doc=doc,
+ )
def with_loader_criteria(
@@ -700,6 +776,10 @@ def relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -1532,6 +1612,12 @@ def relationship(
post_update=post_update,
cascade=cascade,
viewonly=viewonly,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
lazy=lazy,
passive_deletes=passive_deletes,
passive_updates=passive_updates,
@@ -1559,6 +1645,10 @@ def synonym(
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
) -> Synonym[Any]:
@@ -1670,6 +1760,12 @@ def synonym(
map_column=map_column,
descriptor=descriptor,
comparator_factory=comparator_factory,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
info=info,
)
@@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
def deferred(
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
- **kw: Any,
+ group: Optional[str] = None,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
@@ -1803,21 +1909,41 @@ def deferred(
:ref:`deferred_raiseload`
- :param \**kw: additional keyword arguments passed to
- :class:`.ColumnProperty`.
+ Additional arguments are the same as that of :func:`_orm.column_property`.
.. seealso::
:ref:`deferred`
"""
- kw["deferred"] = True
- return ColumnProperty(column, *additional_columns, **kw)
+ return ColumnProperty(
+ column,
+ *additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=True,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
def query_expression(
default_expr: _ORMColumnExprArgument[_T] = sql.null(),
-) -> Mapped[_T]:
+ *,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
"""Indicate an attribute that populates from a query-time SQL expression.
:param default_expr: Optional SQL expression object that will be used in
@@ -1840,7 +1966,18 @@ def query_expression(
:ref:`mapper_querytime_expression`
"""
- prop = ColumnProperty(default_expr)
+ prop = ColumnProperty(
+ default_expr,
+ attribute_options=_AttributeOptions(
+ _NoArg.NO_ARG,
+ repr,
+ _NoArg.NO_ARG,
+ _NoArg.NO_ARG,
+ ),
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
prop.strategy_key = (("query_expression", True),)
return prop
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index 1c343b04c..553a50107 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -33,6 +33,13 @@ from . import clsregistry
from . import instrumentation
from . import interfaces
from . import mapperlib
+from ._orm_constructors import column_property
+from ._orm_constructors import composite
+from ._orm_constructors import deferred
+from ._orm_constructors import mapped_column
+from ._orm_constructors import query_expression
+from ._orm_constructors import relationship
+from ._orm_constructors import synonym
from .attributes import InstrumentedAttribute
from .base import _inspect_mapped_class
from .base import Mapped
@@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor
from .decl_base import _DeferredMapperConfig
from .decl_base import _del_attribute
from .decl_base import _mapper
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .properties import ColumnProperty
+from .properties import MappedColumn
+from .relationships import Relationship
from .state import InstanceState
from .. import exc
from .. import inspection
@@ -60,9 +72,9 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _O
from ._typing import _RegistryType
- from .descriptor_props import Synonym
from .instrumentation import ClassManager
from .interfaces import MapperProperty
+ from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
@@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept(
"""
+@compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+)
+class DCTransformDeclarative(DeclarativeAttributeIntercept):
+ """metaclass that includes @dataclass_transforms"""
+
+
class DeclarativeMeta(
_DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
@@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
cls._sa_registry.map_declaratively(cls)
+class MappedAsDataclass(metaclass=DCTransformDeclarative):
+ """Mixin class to indicate when mapping this class, also convert it to be
+ a dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
+ .. versionadded:: 2.0
+ """
+
+ def __init_subclass__(
+ cls,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> None:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ super().__init_subclass__()
+
+
class DeclarativeBase(
inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
+
The :class:`_orm.DeclarativeBase` allows for the creation of new
declarative bases in such a way that is compatible with type checkers::
@@ -1121,7 +1183,7 @@ class registry:
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(registry=self, metadata=metadata)
+ class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata)
if isinstance(cls, type):
class_dict["__doc__"] = cls.__doc__
@@ -1142,6 +1204,78 @@ class registry:
return metaclass(name, bases, class_dict)
+ @compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+ )
+ @overload
+ def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]:
+ ...
+
+ @overload
+ def mapped_as_dataclass(
+ self,
+ __cls: Literal[None] = ...,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Callable[[Type[_O]], Type[_O]]:
+ ...
+
+ def mapped_as_dataclass(
+ self,
+ __cls: Optional[Type[_O]] = None,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
+ """Class decorator that will apply the Declarative mapping process
+ to a given class, and additionally convert the class to be a
+ Python dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped`
+
+ .. versionadded:: 2.0
+
+
+ """
+
+ def decorate(cls: Type[_O]) -> Type[_O]:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ if __cls:
+ return decorate(__cls)
+ else:
+ return decorate
+
def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
@@ -1174,6 +1308,10 @@ class registry:
that will apply Declarative mapping to subclasses automatically
using a Python metaclass.
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
"""
_as_declarative(self, cls, cls.__dict__)
return cls
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a66421e22..54a272f86 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -10,6 +10,8 @@
from __future__ import annotations
import collections
+import dataclasses
+import re
from typing import Any
from typing import Callable
from typing import cast
@@ -40,6 +42,7 @@ from .base import _is_mapped_class
from .base import InspectionAttr
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
@@ -48,15 +51,18 @@ from .mapper import Mapper as mapper
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
+from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
from .. import util
from ..sql import expression
+from ..sql.base import _NoArg
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig):
"mapper_args",
"mapper_args_fn",
"inherits",
+ "allow_dataclass_fields",
+ "dataclass_setup_arguments",
)
registry: _RegistryType
clsdict_view: _ClassDict
- collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_annotations: Dict[str, Tuple[Any, Any, bool]]
collected_attributes: Dict[str, Any]
local_table: Optional[FromClause]
persist_selectable: Optional[FromClause]
@@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig):
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
+ dataclass_setup_arguments: Optional[Dict[str, Any]]
+ """if the class has SQLAlchemy native dataclass parameters, where
+ we will create a SQLAlchemy dataclass (not a real dataclass).
+
+ """
+
+ allow_dataclass_fields: bool
+ """if true, look for dataclass-processed Field objects on the target
+ class as well as superclasses and extract ORM mapping directives from
+ the "metadata" attribute of each Field"""
+
def __init__(
self,
registry: _RegistryType,
@@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self.declared_columns = util.OrderedSet()
self.column_copies = {}
+ self.dataclass_setup_arguments = dca = getattr(
+ self.cls, "_sa_apply_dc_transforms", None
+ )
+
+ cld = dataclasses.is_dataclass(cls_)
+
+ sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+ # we don't want to consume Field objects from a not-already-dataclass.
+ # the Field objects won't have their "name" or "type" populated,
+ # and while it seems like we could just set these on Field as we
+ # read them, Field is documented as "user read only" and we need to
+ # stay far away from any off-label use of dataclasses APIs.
+ if (not cld or dca) and sdk:
+ raise exc.InvalidRequestError(
+ "SQLAlchemy mapped dataclasses can't consume mapping "
+ "information from dataclass.Field() objects if the immediate "
+ "class is not already a dataclass."
+ )
+
+ # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+ # then also look inside of dataclass.Field() objects yielded by
+ # dataclasses.get_fields(cls) when scanning for attributes
+ self.allow_dataclass_fields = bool(sdk and cld)
+
self._setup_declared_events()
self._scan_attributes()
+ self._setup_dataclasses_transforms()
+
with mapperlib._CONFIGURE_MUTEX:
clsregistry.add_class(
self.classname, self.cls, registry._class_registry
@@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig):
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
- if sa_dataclass_metadata_key is None:
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
@@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"__dict__",
"__weakref__",
"_sa_class_manager",
+ "_sa_apply_dc_transforms",
"__dict__",
"__weakref__",
]
@@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig):
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
-
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
@@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig):
names = util.merge_lists_w_ordering(
[n for n in cls_vars if n not in skip], list(cls_annotations)
)
- if sa_dataclass_metadata_key is None:
+
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def local_attributes_for_class() -> Iterable[
Tuple[str, Any, Any, bool]
@@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig):
name,
obj,
annotation,
- is_dataclass,
+ is_dataclass_field,
) in local_attributes_for_class():
- if name == "__mapper_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not mapper_args_fn and (not class_mapped or check_decl):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- def _mapper_args_fn() -> Dict[str, Any]:
- return dict(cls_as_Decl.__mapper_args__)
-
- mapper_args_fn = _mapper_args_fn
-
- elif name == "__tablename__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not tablename and (not class_mapped or check_decl):
- tablename = cls_as_Decl.__tablename__
- elif name == "__table_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not table_args and (not class_mapped or check_decl):
- table_args = cls_as_Decl.__table_args__
- if not isinstance(
- table_args, (tuple, dict, type(None))
+ if re.match(r"^__.+__$", name):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (
+ not class_mapped or check_decl
):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None"
- )
- if base is not cls:
- inherited_table_args = True
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls_as_Decl.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls_as_Decl.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ else:
+ # skip all other dunder names
+ continue
elif class_mapped:
if _is_declarative_props(obj):
util.warn(
@@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
# already copied columns to the mapped class.
continue
@@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig):
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
- if is_dataclass:
+ if is_dataclass_field:
# access attribute using normal class access
# first, to see if it's been mapped on a
# superclass. note if the dataclasses.field()
@@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret.doc = obj.__doc__
- self.collected_annotations[name] = (
+ self._collect_annotation(
+ name,
obj._collect_return_annotation(),
False,
+ True,
+ obj,
)
elif _is_mapped_annotation(annotation, cls):
- self.collected_annotations[name] = (
- annotation,
- is_dataclass,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
@@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# declarative mapping. however, check for some
# more common mistakes
self._warn_for_decl_attributes(base, name, obj)
- elif is_dataclass and (
+ elif is_dataclass_field and (
name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
@@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig):
obj = obj.fget()
collected_attributes[name] = obj
- self.collected_annotations[name] = (
- annotation,
- True,
+ self._collect_annotation(
+ name, annotation, True, False, obj
)
else:
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, False, None, obj
)
if (
obj is None
@@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = MappedColumn()
elif name in clsdict_view:
collected_attributes[name] = obj
+ # else if the name is not in the cls.__dict__,
+ # don't collect it as an attribute.
+ # we will see the annotation only, which is meaningful
+ # both for mapping and dataclasses setup
if inherited_table_args and not tablename:
table_args = None
@@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
+ def _setup_dataclasses_transforms(self) -> None:
+
+ dataclass_setup_arguments = self.dataclass_setup_arguments
+ if not dataclass_setup_arguments:
+ return
+
+ manager = instrumentation.manager_of_class(self.cls)
+ assert manager is not None
+
+ field_list = [
+ _AttributeOptions._get_arguments_for_make_dataclass(
+ key,
+ anno,
+ self.collected_attributes.get(key, _NoArg.NO_ARG),
+ )
+ for key, anno in (
+ (key, mapped_anno if mapped_anno else raw_anno)
+ for key, (
+ raw_anno,
+ mapped_anno,
+ is_dc,
+ ) in self.collected_annotations.items()
+ )
+ ]
+
+ annotations = {}
+ defaults = {}
+ for item in field_list:
+ if len(item) == 2:
+ name, tp = item # type: ignore
+ elif len(item) == 3:
+ name, tp, spec = item # type: ignore
+ defaults[name] = spec
+ else:
+ assert False
+ annotations[name] = tp
+
+ for k, v in defaults.items():
+ setattr(self.cls, k, v)
+ self.cls.__annotations__ = annotations
+
+ dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+ def _collect_annotation(
+ self,
+ name: str,
+ raw_annotation: _AnnotationScanType,
+ is_dataclass: bool,
+ expect_mapped: Optional[bool],
+ attr_value: Any,
+ ) -> None:
+
+ if expect_mapped is None:
+ expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+ extracted_mapped_annotation = _extract_mapped_subtype(
+ raw_annotation,
+ self.cls,
+ name,
+ type(attr_value),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+ )
+
+ self.collected_annotations[name] = (
+ raw_annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
) -> None:
@@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig):
_undefer_column_name(
k, self.column_copies.get(value, value) # type: ignore
)
- elif isinstance(value, _IntrospectsAnnotations):
- annotation, is_dataclass = self.collected_annotations.get(
- k, (None, False)
- )
- value.declarative_scan(
- self.registry, cls, k, annotation, is_dataclass
- )
+ else:
+ if isinstance(value, _IntrospectsAnnotations):
+ (
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ ) = self.collected_annotations.get(k, (None, None, False))
+ value.declarative_scan(
+ self.registry,
+ cls,
+ k,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
+ if (
+ isinstance(value, (MapperProperty, _MapsColumns))
+ and value._has_dataclass_arguments
+ and not self.dataclass_setup_arguments
+ ):
+ if isinstance(value, MapperProperty):
+ argnames = [
+ "init",
+ "default_factory",
+ "repr",
+ "default",
+ ]
+ else:
+ argnames = ["init", "default_factory", "repr"]
+
+ args = {
+ a
+ for a in argnames
+ if getattr(
+ value._attribute_options, f"dataclasses_{a}"
+ )
+ is not _NoArg.NO_ARG
+ }
+ raise exc.ArgumentError(
+ f"Attribute '{k}' on class {cls} includes dataclasses "
+ f"argument(s): "
+ f"{', '.join(sorted(repr(a) for a in args))} but "
+ f"class does not specify "
+ "SQLAlchemy native dataclass configuration."
+ )
+
our_stuff[k] = value
def _extract_declared_columns(self) -> None:
@@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
+
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
@@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig):
# otherwise, Mapper will map it under the column key.
if mp_to_assign is None and key != col.key:
our_stuff[key] = col
-
elif isinstance(c, Column):
# undefer previously occurred here, and now occurs earlier.
# ensure every column we get here has been named
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 8c89f96aa..a366a9534 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -35,11 +35,11 @@ from .base import LoaderCallableStatus
from .base import Mapped
from .base import PassiveFlag
from .base import SQLORMOperations
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
-from .util import _extract_mapped_subtype
from .util import _none_set
from .. import event
from .. import exc as sa_exc
@@ -200,24 +200,26 @@ class Composite(
def __init__(
self,
- class_: Union[
+ _class_or_attr: Union[
None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
] = None,
*attrs: _CompositeAttrType[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
+ **kwargs: Any,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
- if isinstance(class_, (Mapped, str, sql.ColumnElement)):
- self.attrs = (class_,) + attrs
+ if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)):
+ self.attrs = (_class_or_attr,) + attrs
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_ # type: ignore
+ self.composite_class = _class_or_attr # type: ignore
self.attrs = attrs
self.active_history = active_history
@@ -332,19 +334,15 @@ class Composite(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- MappedColumn = util.preloaded.orm_properties.MappedColumn
-
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- self.composite_class is None,
- is_dataclass_field,
- )
-
+ if (
+ self.composite_class is None
+ and extracted_mapped_annotation is None
+ ):
+ self._raise_for_required(key, cls)
+ argument = extracted_mapped_annotation
if argument and self.composite_class is None:
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
@@ -371,11 +369,18 @@ class Composite(
for param, attr in itertools.zip_longest(
insp.parameters.values(), self.attrs
):
- if param is None or attr is None:
+ if param is None:
raise sa_exc.ArgumentError(
- f"number of arguments to {self.composite_class.__name__} "
- f"class and number of attributes don't match"
+ f"number of composite attributes "
+ f"{len(self.attrs)} exceeds "
+ f"that of the number of attributes in class "
+ f"{self.composite_class.__name__} {len(insp.parameters)}"
)
+ if attr is None:
+ # fill in missing attr spots with empty MappedColumn
+ attr = MappedColumn()
+ self.attrs += (attr,)
+
if isinstance(attr, MappedColumn):
attr.declarative_scan_for_composite(
registry, cls, key, param.name, param.annotation
@@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]):
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ attribute_options: Optional[_AttributeOptions] = None,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
self.name = name
self.map_column = map_column
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 4fa61b7ce..33de2aee9 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -113,6 +113,7 @@ class ClassManager(
"previously known as deferred_scalar_loader"
init_method: Optional[Callable[..., None]]
+ original_init: Optional[Callable[..., None]] = None
factory: Optional[_ManagerFactory]
@@ -229,7 +230,7 @@ class ClassManager(
if finalize and not self._finalized:
self._finalize()
- def _finalize(self):
+ def _finalize(self) -> None:
if self._finalized:
return
self._finalized = True
@@ -238,14 +239,14 @@ class ClassManager(
_instrumentation_factory.dispatch.class_instrument(self.class_)
- def __hash__(self):
+ def __hash__(self) -> int: # type: ignore[override]
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return other is self
@property
- def is_mapped(self):
+ def is_mapped(self) -> bool:
return "mapper" in self.__dict__
@HasMemoized.memoized_attribute
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index b5569ce06..e0034061d 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -19,6 +19,7 @@ are exposed when inspecting mappings.
from __future__ import annotations
import collections
+import dataclasses
import typing
from typing import Any
from typing import Callable
@@ -27,6 +28,8 @@ from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
@@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401
from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
+from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql import operators
from ..sql import roles
from ..sql import visitors
+from ..sql.base import _NoArg
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
@@ -141,6 +146,7 @@ class _IntrospectsAnnotations:
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
@@ -150,6 +156,70 @@ class _IntrospectsAnnotations:
"""
+ def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn:
+ raise sa_exc.ArgumentError(
+ f"Python typing annotation is required for attribute "
+ f'"{cls.__name__}.{key}" when primary argument(s) for '
+ f'"{self.__class__.__name__}" construct are None or not present'
+ )
+
+
+class _AttributeOptions(NamedTuple):
+ """define Python-local attribute behavior options common to all
+ :class:`.MapperProperty` objects.
+
+ Currently this includes dataclass-generation arguments.
+
+ .. versionadded:: 2.0
+
+ """
+
+ dataclasses_init: Union[_NoArg, bool]
+ dataclasses_repr: Union[_NoArg, bool]
+ dataclasses_default: Union[_NoArg, Any]
+ dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+
+ def _as_dataclass_field(self) -> Any:
+ """Return a ``dataclasses.Field`` object given these arguments."""
+
+ kw: Dict[str, Any] = {}
+ if self.dataclasses_default_factory is not _NoArg.NO_ARG:
+ kw["default_factory"] = self.dataclasses_default_factory
+ if self.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = self.dataclasses_default
+ if self.dataclasses_init is not _NoArg.NO_ARG:
+ kw["init"] = self.dataclasses_init
+ if self.dataclasses_repr is not _NoArg.NO_ARG:
+ kw["repr"] = self.dataclasses_repr
+
+ return dataclasses.field(**kw)
+
+ @classmethod
+ def _get_arguments_for_make_dataclass(
+ cls, key: str, annotation: Type[Any], elem: _T
+ ) -> Union[
+ Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+ ]:
+ """given attribute key, annotation, and value from a class, return
+ the argument tuple we would pass to dataclasses.make_dataclass()
+ for this attribute.
+
+ """
+ if isinstance(elem, (MapperProperty, _MapsColumns)):
+ dc_field = elem._attribute_options._as_dataclass_field()
+
+ return (key, annotation, dc_field)
+ elif elem is not _NoArg.NO_ARG:
+ # why is typing not erroring on this?
+ return (key, annotation, elem)
+ else:
+ return (key, annotation)
+
+
+_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
+ _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+)
+
class _MapsColumns(_MappedAttribute[_T]):
"""interface for declarative-capable construct that delivers one or more
@@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]):
__slots__ = ()
+ _attribute_options: _AttributeOptions
+ _has_dataclass_arguments: bool
+
@property
def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
"""return a MapperProperty to be assigned to the declarative mapping"""
@@ -199,6 +272,8 @@ class MapperProperty(
__slots__ = (
"_configure_started",
"_configure_finished",
+ "_attribute_options",
+ "_has_dataclass_arguments",
"parent",
"key",
"info",
@@ -241,6 +316,15 @@ class MapperProperty(
doc: Optional[str]
"""optional documentation string"""
+ _attribute_options: _AttributeOptions
+ """behavioral options for ORM-enabled Python attributes
+
+ .. versionadded:: 2.0
+
+ """
+
+ _has_dataclass_arguments: bool
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
@@ -349,9 +433,20 @@ class MapperProperty(
"""
- def __init__(self) -> None:
+ def __init__(
+ self, attribute_options: Optional[_AttributeOptions] = None
+ ) -> None:
self._configure_started = False
self._configure_finished = False
+ if (
+ attribute_options
+ and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS
+ ):
+ self._has_dataclass_arguments = True
+ self._attribute_options = attribute_options
+ else:
+ self._has_dataclass_arguments = False
+ self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
def init(self) -> None:
"""Called after all mappers are created to assemble
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ad3e9f248..7655f3ae2 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -30,13 +30,14 @@ from . import strategy_options
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
+from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
from .interfaces import StrategizedProperty
from .relationships import Relationship
-from .util import _extract_mapped_subtype
from .util import _orm_full_deannotate
from .. import exc as sa_exc
from .. import ForeignKey
@@ -45,6 +46,7 @@ from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.base import _NoArg
from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
@@ -131,6 +133,7 @@ class ColumnProperty(
self,
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
@@ -141,7 +144,9 @@ class ColumnProperty(
doc: Optional[str] = None,
_instrument: bool = True,
):
- super(ColumnProperty, self).__init__()
+ super(ColumnProperty, self).__init__(
+ attribute_options=attribute_options
+ )
columns = (column,) + additional_columns
self._orig_columns = [
coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
@@ -193,6 +198,7 @@ class ColumnProperty(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.columns[0]
@@ -487,13 +493,38 @@ class MappedColumn(
"foreign_keys",
"_has_nullable",
"deferred",
+ "_attribute_options",
+ "_has_dataclass_arguments",
)
deferred: bool
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
+ _attribute_options: _AttributeOptions
def __init__(self, *arg: Any, **kw: Any):
+ self._attribute_options = attr_opts = kw.pop(
+ "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
+ )
+
+ self._has_dataclass_arguments = False
+
+ if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
+ if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
+ self._has_dataclass_arguments = True
+ kw["default"] = attr_opts.dataclasses_default_factory
+ elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = attr_opts.dataclasses_default
+
+ if (
+ attr_opts.dataclasses_init is not _NoArg.NO_ARG
+ or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
+ ):
+ self._has_dataclass_arguments = True
+
+ if "default" in kw and kw["default"] is _NoArg.NO_ARG:
+ kw.pop("default")
+
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
@@ -509,13 +540,19 @@ class MappedColumn(
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
new._has_nullable = self._has_nullable
+ new._attribute_options = self._attribute_options
+ new._has_dataclass_arguments = self._has_dataclass_arguments
util.set_creation_order(new)
return new
@property
def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
if self.deferred:
- return ColumnProperty(self.column, deferred=True)
+ return ColumnProperty(
+ self.column,
+ deferred=True,
+ attribute_options=self._attribute_options,
+ )
else:
return None
@@ -543,6 +580,7 @@ class MappedColumn(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.column
@@ -553,18 +591,15 @@ class MappedColumn(
sqltype = column.type
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- sqltype._isnull and not self.column.foreign_keys,
- is_dataclass_field,
- )
- if argument is None:
- return
+ if extracted_mapped_annotation is None:
+ if sqltype._isnull and not self.column.foreign_keys:
+ self._raise_for_required(key, cls)
+ else:
+ return
- self._init_column_for_annotation(cls, registry, argument)
+ self._init_column_for_annotation(
+ cls, registry, extracted_mapped_annotation
+ )
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 1186f0f54..deaf52147 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -49,6 +49,7 @@ from .base import class_mapper
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .base import state_str
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
@@ -56,7 +57,6 @@ from .interfaces import ONETOMANY
from .interfaces import PropComparator
from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
-from .util import _extract_mapped_subtype
from .util import _orm_annotate
from .util import _orm_deannotate
from .util import CascadeOptions
@@ -355,6 +355,7 @@ class Relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ attribute_options: Optional[_AttributeOptions] = None,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -380,7 +381,7 @@ class Relationship(
_local_remote_pairs: Optional[_ColumnPairs] = None,
_legacy_inactive_history_style: bool = False,
):
- super(Relationship, self).__init__()
+ super(Relationship, self).__init__(attribute_options=attribute_options)
self.uselist = uselist
self.argument = argument
@@ -1701,18 +1702,19 @@ class Relationship(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- Relationship,
- self.argument is None,
- is_dataclass_field,
- )
- if argument is None:
- return
+ argument = extracted_mapped_annotation
+
+ if extracted_mapped_annotation is None:
+
+ if self.argument is None:
+ self._raise_for_required(key, cls)
+ else:
+ return
+
+ argument = extracted_mapped_annotation
if hasattr(argument, "__origin__"):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index c50cc5bac..520c95672 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
- raw_annotation: Union[type, str], cls: Type[Any]
+ raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
@@ -1969,9 +1969,14 @@ def _extract_mapped_subtype(
attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
- superclasses: Optional[Tuple[Type[Any], ...]] = None,
+ expect_mapped: bool = True,
) -> Optional[Union[type, str]]:
+ """given an annotation, figure out if it's ``Mapped[something]`` and if
+ so, return the ``something`` part.
+ Includes error raise scenarios and other options.
+
+ """
if raw_annotation is None:
if required:
@@ -1989,25 +1994,29 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- # TODO: there don't seem to be tests for the failure
- # conditions here
- if not hasattr(annotated, "__origin__") or (
- not issubclass(
- annotated.__origin__, # type: ignore
- superclasses if superclasses else attr_cls,
- )
- and not issubclass(attr_cls, annotated.__origin__) # type: ignore
+ if not hasattr(annotated, "__origin__") or not is_origin_of(
+ annotated, "Mapped", module="sqlalchemy.orm"
):
- our_annotated_str = (
- annotated.__name__
+ anno_name = (
+ getattr(annotated, "__name__", None)
if not isinstance(annotated, str)
- else repr(annotated)
- )
- raise sa_exc.ArgumentError(
- f'Type annotation for "{cls.__name__}.{key}" should use the '
- f'syntax "Mapped[{our_annotated_str}]" or '
- f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ else None
)
+ if anno_name is None:
+ our_annotated_str = repr(annotated)
+ else:
+ our_annotated_str = anno_name
+
+ if expect_mapped:
+ raise sa_exc.ArgumentError(
+ f'Type annotation for "{cls.__name__}.{key}" '
+ "should use the "
+ f'syntax "Mapped[{our_annotated_str}]" or '
+ f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ )
+
+ else:
+ return annotated
if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 53f76f3ce..d4e4d2dca 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -25,6 +25,7 @@ from .. import event
from .. import util
from ..orm import declarative_base
from ..orm import DeclarativeBase
+from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
@@ -90,7 +91,14 @@ class TestBase:
@config.fixture()
def registry(self, metadata):
- reg = registry(metadata=metadata)
+ reg = registry(
+ metadata=metadata,
+ type_annotation_map={
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ },
+ )
yield reg
reg.dispose()
@@ -109,6 +117,21 @@ class TestBase:
yield Base
Base.registry.dispose()
+ @config.fixture
+ def dc_decl_base(self, metadata):
+ _md = metadata
+
+ class Base(MappedAsDataclass, DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index adbbf143f..4ce1e7ff3 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -230,7 +230,11 @@ def inspect_formatargspec(
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
- with a class."""
+ with a class as an already processed dataclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
@@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
- a class, excluding those that originate from a superclass."""
+ an already processed dataclass, excluding those that originate from a
+ superclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
super_fields: Set[dataclasses.Field[Any]] = set()
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 44e26f609..454de100b 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401
from . import compat
+
+# more zimports issues
+if True:
+ from typing_extensions import ( # noqa: F401
+ dataclass_transform as dataclass_transform,
+ )
+
+
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
diff --git a/pyproject.toml b/pyproject.toml
index 29d59ea69..812d60e91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,6 @@ markers = [
[tool.pyright]
-
reportPrivateUsage = "none"
reportUnusedClass = "none"
reportUnusedFunction = "none"
diff --git a/setup.cfg b/setup.cfg
index 5ef2c6f22..0df41dc7b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,7 +37,7 @@ package_dir =
install_requires =
importlib-metadata;python_version<"3.8"
greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32'))))))
- typing-extensions >= 4
+ typing-extensions >= 4.1.0
[options.extras_require]
asyncio =
diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py
new file mode 100644
index 000000000..aac873723
--- /dev/null
+++ b/test/orm/declarative/test_dc_transforms.py
@@ -0,0 +1,816 @@
+import dataclasses
+import inspect as pyinspect
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Type
+from unittest import mock
+
+from sqlalchemy import Column
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import column_property
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedAsDataclass
+from sqlalchemy.orm import MappedColumn
+from sqlalchemy.orm import registry as _RegistryType
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import synonym
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_regex
+from sqlalchemy.testing import expect_raises
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
+from sqlalchemy.testing import ne_
+
+
+class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
+ def test_basic_constructor_repr_base_cls(
+ self, dc_decl_base: Type[MappedAsDataclass]
+ ):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ class B(dc_decl_base):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ A.__qualname__ = "some_module.A"
+ B.__qualname__ = "some_module.B"
+
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x", "bs"],
+ varargs=None,
+ varkw=None,
+ defaults=(None, mock.ANY),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+ eq_(
+ pyinspect.getfullargspec(B.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x"],
+ varargs=None,
+ varkw=None,
+ defaults=(None,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+ eq_(
+ repr(a2),
+ "some_module.A(id=None, data='10', x=5, "
+ "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+ "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+ )
+
+ a3 = A("data")
+ eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+ def test_basic_constructor_repr_cls_decorator(
+ self, registry: _RegistryType
+ ):
+ @registry.mapped_as_dataclass()
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ @registry.mapped_as_dataclass()
+ class B:
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ A.__qualname__ = "some_module.A"
+ B.__qualname__ = "some_module.B"
+
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x", "bs"],
+ varargs=None,
+ varkw=None,
+ defaults=(None, mock.ANY),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+ eq_(
+ pyinspect.getfullargspec(B.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x"],
+ varargs=None,
+ varkw=None,
+ defaults=(None,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+ eq_(
+ repr(a2),
+ "some_module.A(id=None, data='10', x=5, "
+ "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+ "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+ )
+
+ a3 = A("data")
+ eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+ def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(default="d1")
+ data2: Mapped[str] = mapped_column(default_factory=lambda: "d2")
+
+ a1 = A()
+ eq_(a1.data, "d1")
+ eq_(a1.data2, "d2")
+
+ def test_default_factory_vs_collection_class(
+ self, dc_decl_base: Type[MappedAsDataclass]
+ ):
+ # this is currently the error raised by dataclasses. We can instead
+ # do this validation ourselves, but overall I don't know that we
+ # can hit every validation and rule that's in dataclasses
+ with expect_raises_message(
+ ValueError, "cannot specify both default and default_factory"
+ ):
+
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(
+ default="d1", default_factory=lambda: "d2"
+ )
+
+ def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]):
+ class Person(dc_decl_base):
+ __tablename__ = "person"
+ person_id: Mapped[int] = mapped_column(
+ primary_key=True, init=False
+ )
+ name: Mapped[str]
+ type: Mapped[str] = mapped_column(init=False)
+
+ __mapper_args__ = {"polymorphic_on": type}
+
+ class Engineer(Person):
+ __tablename__ = "engineer"
+
+ person_id: Mapped[int] = mapped_column(
+ ForeignKey("person.person_id"), primary_key=True, init=False
+ )
+
+ status: Mapped[str] = mapped_column(String(30))
+ engineer_name: Mapped[str]
+ primary_language: Mapped[str]
+
+ e1 = Engineer("nm", "st", "en", "pl")
+ eq_(e1.name, "nm")
+ eq_(e1.status, "st")
+ eq_(e1.engineer_name, "en")
+ eq_(e1.primary_language, "pl")
+
+ def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]):
+ """We will be telling users "this is a dataclass that is also
+ mapped". Therefore, they will want *any* kind of attribute to do what
+ it would normally do in a dataclass, including normal types without any
+ field and explicit use of dataclasses.field(). additionally, we'd like
+ ``Mapped`` to mean "persist this attribute". So the absence of
+ ``Mapped`` should also mean something too.
+
+ """
+
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ ctrl_one: str
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+ some_field: int = dataclasses.field(default=5)
+
+ some_none_field: Optional[str] = None
+
+ a1 = A("ctrlone", "datafield")
+ eq_(a1.some_field, 5)
+ eq_(a1.some_none_field, None)
+
+ # only Mapped[] is mapped
+ self.assert_compile(select(A), "SELECT a.id, a.data FROM a")
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=[
+ "self",
+ "ctrl_one",
+ "data",
+ "some_field",
+ "some_none_field",
+ ],
+ varargs=None,
+ varkw=None,
+ defaults=(5, None),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]):
+ class Person(decl_base):
+ __tablename__ = "person"
+ person_id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ type: Mapped[str] = mapped_column()
+
+ __mapper_args__ = {"polymorphic_on": type}
+
+ class Engineer(MappedAsDataclass, Person):
+ __tablename__ = "engineer"
+
+ person_id: Mapped[int] = mapped_column(
+ ForeignKey("person.person_id"), primary_key=True, init=False
+ )
+
+ status: Mapped[str] = mapped_column(String(30))
+ engineer_name: Mapped[str]
+ primary_language: Mapped[str]
+
+ e1 = Engineer("st", "en", "pl")
+ eq_(e1.status, "st")
+ eq_(e1.engineer_name, "en")
+ eq_(e1.primary_language, "pl")
+
+ eq_(
+ pyinspect.getfullargspec(Person.__init__),
+ # the boring **kw __init__
+ pyinspect.FullArgSpec(
+ args=["self"],
+ varargs=None,
+ varkw="kwargs",
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ eq_(
+ pyinspect.getfullargspec(Engineer.__init__),
+ # the exciting dataclasses __init__
+ pyinspect.FullArgSpec(
+ args=["self", "status", "engineer_name", "primary_language"],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+
+class RelationshipDefaultFactoryTest(fixtures.TestBase):
+ def test_list(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=lambda: [B(data="hi")]
+ )
+
+ class B(dc_decl_base):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ a1 = A()
+ eq_(a1.bs[0].data, "hi")
+
+ def test_set(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[Set["B"]] = relationship( # noqa: F821
+ default_factory=lambda: {B(data="hi")}
+ )
+
+ class B(dc_decl_base, unsafe_hash=True):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ a1 = A()
+ eq_(a1.bs.pop().data, "hi")
+
+ def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[Set["B"]] = relationship( # noqa: F821
+ default_factory=lambda: [B(data="hi")]
+ )
+
+ class B(dc_decl_base, unsafe_hash=True):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ # old school collection mismatch error FTW
+ with expect_raises_message(
+ TypeError, "Incompatible collection type: list is not set-like"
+ ):
+ A()
+
+ def test_replace_operation_works_w_history_etc(
+ self, registry: _RegistryType
+ ):
+ @registry.mapped_as_dataclass
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ @registry.mapped_as_dataclass
+ class B:
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ registry.metadata.create_all(testing.db)
+
+ with Session(testing.db) as sess:
+ a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")])
+ sess.add(a1)
+ sess.commit()
+
+ a2 = dataclasses.replace(a1, x=12, bs=[B("b4")])
+
+ assert a1 in sess
+ assert not sess.is_modified(a1, include_collections=True)
+ assert a2 not in sess
+ eq_(inspect(a2).attrs.x.history, ([12], (), ()))
+ sess.add(a2)
+ sess.commit()
+
+ eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12])
+ eq_(
+ sess.scalars(select(B.data).order_by(B.id)).all(),
+ ["b1", "b2", "b3", "b4"],
+ )
+
+ def test_post_init(self, registry: _RegistryType):
+ @registry.mapped_as_dataclass
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(init=False)
+
+ def __post_init__(self):
+ self.data = "some data"
+
+ a1 = A()
+ eq_(a1.data, "some data")
+
+ def test_no_field_args_w_new_style(self, registry: _RegistryType):
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "SQLAlchemy mapped dataclasses can't consume mapping information",
+ ):
+
+ @registry.mapped_as_dataclass()
+ class A:
+ __tablename__ = "a"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+
+ def test_no_field_args_w_new_style_two(self, registry: _RegistryType):
+ @dataclasses.dataclass
+ class Base:
+ pass
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "SQLAlchemy mapped dataclasses can't consume mapping information",
+ ):
+
+ @registry.mapped_as_dataclass()
+ class A(Base):
+ __tablename__ = "a"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+
+
+class DataclassArgsTest(fixtures.TestBase):
+ dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
+
+ @testing.fixture(params=dc_arg_names)
+ def dc_argument_fixture(self, request: Any, registry: _RegistryType):
+ name = request.param
+
+ args = {n: n == name for n in self.dc_arg_names}
+ if args["order"]:
+ args["eq"] = True
+ yield args
+
+ @testing.fixture(
+ params=["mapped_column", "synonym", "deferred", "column_property"]
+ )
+ def mapped_expr_constructor(self, request):
+ name = request.param
+
+ if name == "mapped_column":
+ yield mapped_column(default=7, init=True)
+ elif name == "synonym":
+ yield synonym("some_int", default=7, init=True)
+ elif name == "deferred":
+ yield deferred(Column(Integer), default=7, init=True)
+ elif name == "column_property":
+ yield column_property(Column(Integer), default=7, init=True)
+
+ def test_attrs_rejected_if_not_a_dc(
+ self, mapped_expr_constructor, decl_base: Type[DeclarativeBase]
+ ):
+ if isinstance(mapped_expr_constructor, MappedColumn):
+ unwanted_args = "'init'"
+ else:
+ unwanted_args = "'default', 'init'"
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Attribute 'x' on class .*A.* includes dataclasses "
+ r"argument\(s\): "
+ rf"{unwanted_args} but class does not specify SQLAlchemy native "
+ "dataclass configuration",
+ ):
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ x: Mapped[int] = mapped_expr_constructor
+
+ def _assert_cls(self, cls, dc_arguments):
+
+ if dc_arguments["init"]:
+
+ def create(data, x):
+ return cls(data, x)
+
+ else:
+
+ def create(data, x):
+ a1 = cls()
+ a1.data = data
+ a1.x = x
+ return a1
+
+ for n in self.dc_arg_names:
+ if dc_arguments[n]:
+ getattr(self, f"_assert_{n}")(cls, create, dc_arguments)
+ else:
+ getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments)
+
+ if dc_arguments["init"]:
+ a1 = cls("some data")
+ eq_(a1.x, 7)
+
+ a1 = create("some data", 15)
+ some_int = a1.some_int
+ eq_(
+ dataclasses.asdict(a1),
+ {"data": "some data", "id": None, "some_int": some_int, "x": 15},
+ )
+ eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15))
+
+ def _assert_unsafe_hash(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ hash(a1)
+
+ def _assert_not_unsafe_hash(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+
+ if dc_arguments["eq"]:
+ with expect_raises(TypeError):
+ hash(a1)
+ else:
+ hash(a1)
+
+ def _assert_eq(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ a2 = create("d2", 10)
+ a3 = create("d1", 5)
+
+ eq_(a1, a3)
+ ne_(a1, a2)
+
+ def _assert_not_eq(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ a2 = create("d2", 10)
+ a3 = create("d1", 5)
+
+ eq_(a1, a1)
+ ne_(a1, a3)
+ ne_(a1, a2)
+
+ def _assert_order(self, cls, create, dc_arguments):
+ is_false(create("g", 10) < create("b", 7))
+
+ is_true(create("g", 10) > create("b", 7))
+
+ is_false(create("g", 10) <= create("b", 7))
+
+ is_true(create("g", 10) >= create("b", 7))
+
+ eq_(
+ list(sorted([create("g", 10), create("g", 5), create("b", 7)])),
+ [
+ create("b", 7),
+ create("g", 5),
+ create("g", 10),
+ ],
+ )
+
+ def _assert_not_order(self, cls, create, dc_arguments):
+ with expect_raises(TypeError):
+ create("g", 10) < create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) > create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) <= create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) >= create("b", 7)
+
+ def _assert_repr(self, cls, create, dc_arguments):
+ a1 = create("some data", 12)
+ eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)")
+
+ def _assert_not_repr(self, cls, create, dc_arguments):
+ a1 = create("some data", 12)
+ eq_regex(repr(a1), r"<.*A object at 0x.*>")
+
+ def _assert_init(self, cls, create, dc_arguments):
+ a1 = cls("some data", 5)
+
+ eq_(a1.data, "some data")
+ eq_(a1.x, 5)
+
+ a2 = cls(data="some data", x=5)
+ eq_(a2.data, "some data")
+ eq_(a2.x, 5)
+
+ a3 = cls(data="some data")
+ eq_(a3.data, "some data")
+ eq_(a3.x, 7)
+
+ def _assert_not_init(self, cls, create, dc_arguments):
+
+ with expect_raises(TypeError):
+ cls("Some data", 5)
+
+ # we run real "dataclasses" on the class. so with init=False, it
+ # doesn't touch what was there, and the SQLA default constructor
+ # gets put on.
+ a1 = cls(data="some data")
+ eq_(a1.data, "some data")
+ eq_(a1.x, None)
+
+ a1 = cls()
+ eq_(a1.data, None)
+
+ # no constructor, it sets None for x...ok
+ eq_(a1.x, None)
+
+ def test_dc_arguments_decorator(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ registry: _RegistryType,
+ ):
+ @registry.mapped_as_dataclass(**dc_argument_fixture)
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self._assert_cls(A, dc_argument_fixture)
+
+ def test_dc_arguments_base(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ registry: _RegistryType,
+ ):
+ reg = registry
+
+ class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture):
+ registry = reg
+
+ class A(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self.A = A
+
+ def test_dc_arguments_perclass(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ decl_base: Type[DeclarativeBase],
+ ):
+ class A(MappedAsDataclass, decl_base, **dc_argument_fixture):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self.A = A
+
+
+class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+ @dataclasses.dataclass
+ class Point:
+ x: int
+ y: int
+
+ class Edge(dc_decl_base):
+ __tablename__ = "edge"
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ graph_id: Mapped[int] = mapped_column(
+ ForeignKey("graph.id"), init=False
+ )
+
+ start: Mapped[Point] = composite(
+ Point, mapped_column("x1"), mapped_column("y1"), default=None
+ )
+
+ end: Mapped[Point] = composite(
+ Point, mapped_column("x2"), mapped_column("y2"), default=None
+ )
+
+ class Graph(dc_decl_base):
+ __tablename__ = "graph"
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ edges: Mapped[List[Edge]] = relationship()
+
+ Point.__qualname__ = "mymodel.Point"
+ Edge.__qualname__ = "mymodel.Edge"
+ Graph.__qualname__ = "mymodel.Graph"
+ g = Graph(
+ edges=[
+ Edge(start=Point(1, 2), end=Point(3, 4)),
+ Edge(start=Point(7, 8), end=Point(5, 6)),
+ ]
+ )
+ eq_(
+ repr(g),
+ "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, "
+ "graph_id=None, start=mymodel.Point(x=1, y=2), "
+ "end=mymodel.Point(x=3, y=4)), "
+ "mymodel.Edge(id=None, graph_id=None, "
+ "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])",
+ )
+
+ def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(dc_decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, init=False, repr=False
+ )
+ name: Mapped[str] = mapped_column()
+
+ address: Mapped[Address] = composite(
+ Address,
+ mapped_column(),
+ mapped_column(),
+ mapped_column("zip"),
+ default=None,
+ )
+
+ Address.__qualname__ = "mymodule.Address"
+ User.__qualname__ = "mymodule.User"
+ u = User(
+ name="user 1",
+ address=Address("123 anywhere street", "NY", "12345"),
+ )
+ u2 = User("u2")
+ eq_(
+ repr(u),
+ "mymodule.User(name='user 1', "
+ "address=mymodule.Address(street='123 anywhere street', "
+ "state='NY', zip_='12345'))",
+ )
+ eq_(repr(u2), "mymodule.User(name='u2', address=None)")
diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py
index d7d19821c..865735439 100644
--- a/test/orm/declarative/test_typed_mapping.py
+++ b/test/orm/declarative/test_typed_mapping.py
@@ -190,6 +190,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
is_true(User.__table__.c.data.nullable)
assert isinstance(User.__table__.c.created_at.type, DateTime)
+ def test_column_default(self, decl_base):
+ class MyClass(decl_base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column(default="some default")
+
+ mc = MyClass()
+ assert "data" not in mc.__dict__
+
+ eq_(MyClass.__table__.c.data.default.arg, "some default")
+
def test_anno_w_fixed_table(self, decl_base):
users = Table(
"users",
@@ -959,7 +971,7 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
with expect_raises_message(
ArgumentError,
r"Type annotation for \"User.address\" should use the syntax "
- r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"",
+ r"\"Mapped\['Address'\]\"",
):
class User(decl_base):
@@ -1068,6 +1080,38 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
# round trip!
eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
+ def test_cls_annotated_no_mapped_cols_setup(self, decl_base):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column()
+
+ address: Mapped[Address] = composite()
+
+ decl_base.metadata.create_all(testing.db)
+
+ with fixture_session() as sess:
+ sess.add(
+ User(
+ name="user 1",
+ address=Address("123 anywhere street", "NY", "12345"),
+ )
+ )
+ sess.commit()
+
+ with fixture_session() as sess:
+ u1 = sess.scalar(select(User))
+
+ # round trip!
+ eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
+
def test_one_col_setup(self, decl_base):
@dataclasses.dataclass
class Address: