summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/__init__.py1
-rw-r--r--lib/sqlalchemy/orm/_orm_constructors.py167
-rw-r--r--lib/sqlalchemy/orm/decl_api.py142
-rw-r--r--lib/sqlalchemy/orm/decl_base.py315
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py48
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py9
-rw-r--r--lib/sqlalchemy/orm/interfaces.py97
-rw-r--r--lib/sqlalchemy/orm/properties.py63
-rw-r--r--lib/sqlalchemy/orm/relationships.py26
-rw-r--r--lib/sqlalchemy/orm/util.py45
-rw-r--r--lib/sqlalchemy/testing/fixtures.py25
-rw-r--r--lib/sqlalchemy/util/compat.py13
-rw-r--r--lib/sqlalchemy/util/typing.py8
13 files changed, 799 insertions, 160 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index b7d1df532..4f19ba946 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -60,6 +60,7 @@ from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta
from .decl_api import DeclarativeMeta as DeclarativeMeta
from .decl_api import declared_attr as declared_attr
from .decl_api import has_inherited_table as has_inherited_table
+from .decl_api import MappedAsDataclass as MappedAsDataclass
from .decl_api import registry as registry
from .decl_api import synonym_for as synonym_for
from .descriptor_props import Composite as Composite
diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py
index 0692cac09..ece6a52be 100644
--- a/lib/sqlalchemy/orm/_orm_constructors.py
+++ b/lib/sqlalchemy/orm/_orm_constructors.py
@@ -21,9 +21,9 @@ from typing import Union
from . import mapperlib as mapperlib
from ._typing import _O
-from .base import Mapped
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .properties import ColumnProperty
from .properties import MappedColumn
from .query import AliasOption
@@ -37,6 +37,8 @@ from .util import LoaderCriteriaOption
from .. import sql
from .. import util
from ..exc import InvalidRequestError
+from ..sql._typing import _no_kw
+from ..sql.base import _NoArg
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
@@ -105,6 +107,10 @@ def mapped_column(
Union[_TypeEngineArgument[Any], SchemaEventTarget]
] = None,
*args: SchemaEventTarget,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
] = SchemaConst.NULL_UNSPECIFIED,
@@ -113,7 +119,6 @@ def mapped_column(
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[Any]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
doc: Optional[str] = None,
key: Optional[str] = None,
index: Optional[bool] = None,
@@ -300,6 +305,12 @@ def mapped_column(
type_=type_,
autoincrement=autoincrement,
default=default,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
key=key,
index=index,
@@ -325,6 +336,10 @@ def column_property(
deferred: bool = False,
raiseload: bool = False,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
active_history: bool = False,
expire_on_flush: bool = True,
info: Optional[_InfoType] = None,
@@ -416,6 +431,12 @@ def column_property(
return ColumnProperty(
column,
*additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
group=group,
deferred=deferred,
raiseload=raiseload,
@@ -429,25 +450,61 @@ def column_property(
@overload
def composite(
- class_: Type[_CC],
+ _class_or_attr: Type[_CC],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[_CC]:
...
@overload
def composite(
+ _class_or_attr: _CompositeAttrType[Any],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
...
def composite(
- class_: Any = None,
+ _class_or_attr: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
@@ -497,7 +554,26 @@ def composite(
:attr:`.MapperProperty.info` attribute of this object.
"""
- return Composite(class_, *attrs, **kwargs)
+ if __kw:
+ raise _no_kw()
+
+ return Composite(
+ _class_or_attr,
+ *attrs,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=deferred,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ info=info,
+ doc=doc,
+ )
def with_loader_criteria(
@@ -700,6 +776,10 @@ def relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -1532,6 +1612,12 @@ def relationship(
post_update=post_update,
cascade=cascade,
viewonly=viewonly,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
lazy=lazy,
passive_deletes=passive_deletes,
passive_updates=passive_updates,
@@ -1559,6 +1645,10 @@ def synonym(
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
) -> Synonym[Any]:
@@ -1670,6 +1760,12 @@ def synonym(
map_column=map_column,
descriptor=descriptor,
comparator_factory=comparator_factory,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
info=info,
)
@@ -1784,7 +1880,17 @@ def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
def deferred(
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
- **kw: Any,
+ group: Optional[str] = None,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
@@ -1803,21 +1909,41 @@ def deferred(
:ref:`deferred_raiseload`
- :param \**kw: additional keyword arguments passed to
- :class:`.ColumnProperty`.
+ Additional arguments are the same as that of :func:`_orm.column_property`.
.. seealso::
:ref:`deferred`
"""
- kw["deferred"] = True
- return ColumnProperty(column, *additional_columns, **kw)
+ return ColumnProperty(
+ column,
+ *additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=True,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
def query_expression(
default_expr: _ORMColumnExprArgument[_T] = sql.null(),
-) -> Mapped[_T]:
+ *,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
"""Indicate an attribute that populates from a query-time SQL expression.
:param default_expr: Optional SQL expression object that will be used in
@@ -1840,7 +1966,18 @@ def query_expression(
:ref:`mapper_querytime_expression`
"""
- prop = ColumnProperty(default_expr)
+ prop = ColumnProperty(
+ default_expr,
+ attribute_options=_AttributeOptions(
+ _NoArg.NO_ARG,
+ repr,
+ _NoArg.NO_ARG,
+ _NoArg.NO_ARG,
+ ),
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
prop.strategy_key = (("query_expression", True),)
return prop
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index 1c343b04c..553a50107 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -33,6 +33,13 @@ from . import clsregistry
from . import instrumentation
from . import interfaces
from . import mapperlib
+from ._orm_constructors import column_property
+from ._orm_constructors import composite
+from ._orm_constructors import deferred
+from ._orm_constructors import mapped_column
+from ._orm_constructors import query_expression
+from ._orm_constructors import relationship
+from ._orm_constructors import synonym
from .attributes import InstrumentedAttribute
from .base import _inspect_mapped_class
from .base import Mapped
@@ -42,8 +49,13 @@ from .decl_base import _declarative_constructor
from .decl_base import _DeferredMapperConfig
from .decl_base import _del_attribute
from .decl_base import _mapper
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .properties import ColumnProperty
+from .properties import MappedColumn
+from .relationships import Relationship
from .state import InstanceState
from .. import exc
from .. import inspection
@@ -60,9 +72,9 @@ from ..util.typing import Literal
if TYPE_CHECKING:
from ._typing import _O
from ._typing import _RegistryType
- from .descriptor_props import Synonym
from .instrumentation import ClassManager
from .interfaces import MapperProperty
+ from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
@@ -120,6 +132,26 @@ class DeclarativeAttributeIntercept(
"""
+@compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+)
+class DCTransformDeclarative(DeclarativeAttributeIntercept):
+ """metaclass that includes @dataclass_transforms"""
+
+
class DeclarativeMeta(
_DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
@@ -543,12 +575,42 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]):
cls._sa_registry.map_declaratively(cls)
+class MappedAsDataclass(metaclass=DCTransformDeclarative):
+ """Mixin class to indicate when mapping this class, also convert it to be
+ a dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
+ .. versionadded:: 2.0
+ """
+
+ def __init_subclass__(
+ cls,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> None:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ super().__init_subclass__()
+
+
class DeclarativeBase(
inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
+
The :class:`_orm.DeclarativeBase` allows for the creation of new
declarative bases in such a way that is compatible with type checkers::
@@ -1121,7 +1183,7 @@ class registry:
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(registry=self, metadata=metadata)
+ class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata)
if isinstance(cls, type):
class_dict["__doc__"] = cls.__doc__
@@ -1142,6 +1204,78 @@ class registry:
return metaclass(name, bases, class_dict)
+ @compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+ )
+ @overload
+ def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]:
+ ...
+
+ @overload
+ def mapped_as_dataclass(
+ self,
+ __cls: Literal[None] = ...,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Callable[[Type[_O]], Type[_O]]:
+ ...
+
+ def mapped_as_dataclass(
+ self,
+ __cls: Optional[Type[_O]] = None,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
+ """Class decorator that will apply the Declarative mapping process
+ to a given class, and additionally convert the class to be a
+ Python dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped`
+
+ .. versionadded:: 2.0
+
+
+ """
+
+ def decorate(cls: Type[_O]) -> Type[_O]:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ if __cls:
+ return decorate(__cls)
+ else:
+ return decorate
+
def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
@@ -1174,6 +1308,10 @@ class registry:
that will apply Declarative mapping to subclasses automatically
using a Python metaclass.
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
"""
_as_declarative(self, cls, cls.__dict__)
return cls
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index a66421e22..54a272f86 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -10,6 +10,8 @@
from __future__ import annotations
import collections
+import dataclasses
+import re
from typing import Any
from typing import Callable
from typing import cast
@@ -40,6 +42,7 @@ from .base import _is_mapped_class
from .base import InspectionAttr
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
@@ -48,15 +51,18 @@ from .mapper import Mapper as mapper
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
+from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
from .. import util
from ..sql import expression
+from ..sql.base import _NoArg
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
if TYPE_CHECKING:
@@ -392,11 +398,13 @@ class _ClassScanMapperConfig(_MapperConfig):
"mapper_args",
"mapper_args_fn",
"inherits",
+ "allow_dataclass_fields",
+ "dataclass_setup_arguments",
)
registry: _RegistryType
clsdict_view: _ClassDict
- collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_annotations: Dict[str, Tuple[Any, Any, bool]]
collected_attributes: Dict[str, Any]
local_table: Optional[FromClause]
persist_selectable: Optional[FromClause]
@@ -411,6 +419,17 @@ class _ClassScanMapperConfig(_MapperConfig):
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
+ dataclass_setup_arguments: Optional[Dict[str, Any]]
+ """if the class has SQLAlchemy native dataclass parameters, where
+ we will create a SQLAlchemy dataclass (not a real dataclass).
+
+ """
+
+ allow_dataclass_fields: bool
+ """if true, look for dataclass-processed Field objects on the target
+ class as well as superclasses and extract ORM mapping directives from
+ the "metadata" attribute of each Field"""
+
def __init__(
self,
registry: _RegistryType,
@@ -434,10 +453,37 @@ class _ClassScanMapperConfig(_MapperConfig):
self.declared_columns = util.OrderedSet()
self.column_copies = {}
+ self.dataclass_setup_arguments = dca = getattr(
+ self.cls, "_sa_apply_dc_transforms", None
+ )
+
+ cld = dataclasses.is_dataclass(cls_)
+
+ sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+ # we don't want to consume Field objects from a not-already-dataclass.
+ # the Field objects won't have their "name" or "type" populated,
+ # and while it seems like we could just set these on Field as we
+ # read them, Field is documented as "user read only" and we need to
+ # stay far away from any off-label use of dataclasses APIs.
+ if (not cld or dca) and sdk:
+ raise exc.InvalidRequestError(
+ "SQLAlchemy mapped dataclasses can't consume mapping "
+ "information from dataclass.Field() objects if the immediate "
+ "class is not already a dataclass."
+ )
+
+ # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+ # then also look inside of dataclass.Field() objects yielded by
+ # dataclasses.get_fields(cls) when scanning for attributes
+ self.allow_dataclass_fields = bool(sdk and cld)
+
self._setup_declared_events()
self._scan_attributes()
+ self._setup_dataclasses_transforms()
+
with mapperlib._CONFIGURE_MUTEX:
clsregistry.add_class(
self.classname, self.cls, registry._class_registry
@@ -477,11 +523,15 @@ class _ClassScanMapperConfig(_MapperConfig):
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
- if sa_dataclass_metadata_key is None:
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
@@ -551,6 +601,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"__dict__",
"__weakref__",
"_sa_class_manager",
+ "_sa_apply_dc_transforms",
"__dict__",
"__weakref__",
]
@@ -563,10 +614,6 @@ class _ClassScanMapperConfig(_MapperConfig):
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
-
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
@@ -576,7 +623,15 @@ class _ClassScanMapperConfig(_MapperConfig):
names = util.merge_lists_w_ordering(
[n for n in cls_vars if n not in skip], list(cls_annotations)
)
- if sa_dataclass_metadata_key is None:
+
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def local_attributes_for_class() -> Iterable[
Tuple[str, Any, Any, bool]
@@ -652,45 +707,51 @@ class _ClassScanMapperConfig(_MapperConfig):
name,
obj,
annotation,
- is_dataclass,
+ is_dataclass_field,
) in local_attributes_for_class():
- if name == "__mapper_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not mapper_args_fn and (not class_mapped or check_decl):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- def _mapper_args_fn() -> Dict[str, Any]:
- return dict(cls_as_Decl.__mapper_args__)
-
- mapper_args_fn = _mapper_args_fn
-
- elif name == "__tablename__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not tablename and (not class_mapped or check_decl):
- tablename = cls_as_Decl.__tablename__
- elif name == "__table_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not table_args and (not class_mapped or check_decl):
- table_args = cls_as_Decl.__table_args__
- if not isinstance(
- table_args, (tuple, dict, type(None))
+ if re.match(r"^__.+__$", name):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (
+ not class_mapped or check_decl
):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None"
- )
- if base is not cls:
- inherited_table_args = True
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls_as_Decl.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls_as_Decl.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ else:
+ # skip all other dunder names
+ continue
elif class_mapped:
if _is_declarative_props(obj):
util.warn(
@@ -706,9 +767,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
# already copied columns to the mapped class.
continue
@@ -745,7 +805,7 @@ class _ClassScanMapperConfig(_MapperConfig):
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
- if is_dataclass:
+ if is_dataclass_field:
# access attribute using normal class access
# first, to see if it's been mapped on a
# superclass. note if the dataclasses.field()
@@ -789,14 +849,16 @@ class _ClassScanMapperConfig(_MapperConfig):
):
ret.doc = obj.__doc__
- self.collected_annotations[name] = (
+ self._collect_annotation(
+ name,
obj._collect_return_annotation(),
False,
+ True,
+ obj,
)
elif _is_mapped_annotation(annotation, cls):
- self.collected_annotations[name] = (
- annotation,
- is_dataclass,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
@@ -809,7 +871,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# declarative mapping. however, check for some
# more common mistakes
self._warn_for_decl_attributes(base, name, obj)
- elif is_dataclass and (
+ elif is_dataclass_field and (
name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
@@ -826,14 +888,12 @@ class _ClassScanMapperConfig(_MapperConfig):
obj = obj.fget()
collected_attributes[name] = obj
- self.collected_annotations[name] = (
- annotation,
- True,
+ self._collect_annotation(
+ name, annotation, True, False, obj
)
else:
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, False, None, obj
)
if (
obj is None
@@ -843,6 +903,10 @@ class _ClassScanMapperConfig(_MapperConfig):
collected_attributes[name] = MappedColumn()
elif name in clsdict_view:
collected_attributes[name] = obj
+ # else if the name is not in the cls.__dict__,
+ # don't collect it as an attribute.
+ # we will see the annotation only, which is meaningful
+ # both for mapping and dataclasses setup
if inherited_table_args and not tablename:
table_args = None
@@ -851,6 +915,77 @@ class _ClassScanMapperConfig(_MapperConfig):
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
+ def _setup_dataclasses_transforms(self) -> None:
+
+ dataclass_setup_arguments = self.dataclass_setup_arguments
+ if not dataclass_setup_arguments:
+ return
+
+ manager = instrumentation.manager_of_class(self.cls)
+ assert manager is not None
+
+ field_list = [
+ _AttributeOptions._get_arguments_for_make_dataclass(
+ key,
+ anno,
+ self.collected_attributes.get(key, _NoArg.NO_ARG),
+ )
+ for key, anno in (
+ (key, mapped_anno if mapped_anno else raw_anno)
+ for key, (
+ raw_anno,
+ mapped_anno,
+ is_dc,
+ ) in self.collected_annotations.items()
+ )
+ ]
+
+ annotations = {}
+ defaults = {}
+ for item in field_list:
+ if len(item) == 2:
+ name, tp = item # type: ignore
+ elif len(item) == 3:
+ name, tp, spec = item # type: ignore
+ defaults[name] = spec
+ else:
+ assert False
+ annotations[name] = tp
+
+ for k, v in defaults.items():
+ setattr(self.cls, k, v)
+ self.cls.__annotations__ = annotations
+
+ dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+ def _collect_annotation(
+ self,
+ name: str,
+ raw_annotation: _AnnotationScanType,
+ is_dataclass: bool,
+ expect_mapped: Optional[bool],
+ attr_value: Any,
+ ) -> None:
+
+ if expect_mapped is None:
+ expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+ extracted_mapped_annotation = _extract_mapped_subtype(
+ raw_annotation,
+ self.cls,
+ name,
+ type(attr_value),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+ )
+
+ self.collected_annotations[name] = (
+ raw_annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
) -> None:
@@ -982,13 +1117,53 @@ class _ClassScanMapperConfig(_MapperConfig):
_undefer_column_name(
k, self.column_copies.get(value, value) # type: ignore
)
- elif isinstance(value, _IntrospectsAnnotations):
- annotation, is_dataclass = self.collected_annotations.get(
- k, (None, False)
- )
- value.declarative_scan(
- self.registry, cls, k, annotation, is_dataclass
- )
+ else:
+ if isinstance(value, _IntrospectsAnnotations):
+ (
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ ) = self.collected_annotations.get(k, (None, None, False))
+ value.declarative_scan(
+ self.registry,
+ cls,
+ k,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
+ if (
+ isinstance(value, (MapperProperty, _MapsColumns))
+ and value._has_dataclass_arguments
+ and not self.dataclass_setup_arguments
+ ):
+ if isinstance(value, MapperProperty):
+ argnames = [
+ "init",
+ "default_factory",
+ "repr",
+ "default",
+ ]
+ else:
+ argnames = ["init", "default_factory", "repr"]
+
+ args = {
+ a
+ for a in argnames
+ if getattr(
+ value._attribute_options, f"dataclasses_{a}"
+ )
+ is not _NoArg.NO_ARG
+ }
+ raise exc.ArgumentError(
+ f"Attribute '{k}' on class {cls} includes dataclasses "
+ f"argument(s): "
+ f"{', '.join(sorted(repr(a) for a in args))} but "
+ f"class does not specify "
+ "SQLAlchemy native dataclass configuration."
+ )
+
our_stuff[k] = value
def _extract_declared_columns(self) -> None:
@@ -997,6 +1172,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
+
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
@@ -1019,7 +1195,6 @@ class _ClassScanMapperConfig(_MapperConfig):
# otherwise, Mapper will map it under the column key.
if mp_to_assign is None and key != col.key:
our_stuff[key] = col
-
elif isinstance(c, Column):
# undefer previously occurred here, and now occurs earlier.
# ensure every column we get here has been named
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 8c89f96aa..a366a9534 100644
--- a/lib/sqlalchemy/orm/descriptor_props.py
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -35,11 +35,11 @@ from .base import LoaderCallableStatus
from .base import Mapped
from .base import PassiveFlag
from .base import SQLORMOperations
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
-from .util import _extract_mapped_subtype
from .util import _none_set
from .. import event
from .. import exc as sa_exc
@@ -200,24 +200,26 @@ class Composite(
def __init__(
self,
- class_: Union[
+ _class_or_attr: Union[
None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
] = None,
*attrs: _CompositeAttrType[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
+ **kwargs: Any,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
- if isinstance(class_, (Mapped, str, sql.ColumnElement)):
- self.attrs = (class_,) + attrs
+ if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)):
+ self.attrs = (_class_or_attr,) + attrs
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_ # type: ignore
+ self.composite_class = _class_or_attr # type: ignore
self.attrs = attrs
self.active_history = active_history
@@ -332,19 +334,15 @@ class Composite(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- MappedColumn = util.preloaded.orm_properties.MappedColumn
-
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- self.composite_class is None,
- is_dataclass_field,
- )
-
+ if (
+ self.composite_class is None
+ and extracted_mapped_annotation is None
+ ):
+ self._raise_for_required(key, cls)
+ argument = extracted_mapped_annotation
if argument and self.composite_class is None:
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
@@ -371,11 +369,18 @@ class Composite(
for param, attr in itertools.zip_longest(
insp.parameters.values(), self.attrs
):
- if param is None or attr is None:
+ if param is None:
raise sa_exc.ArgumentError(
- f"number of arguments to {self.composite_class.__name__} "
- f"class and number of attributes don't match"
+ f"number of composite attributes "
+ f"{len(self.attrs)} exceeds "
+ f"that of the number of attributes in class "
+ f"{self.composite_class.__name__} {len(insp.parameters)}"
)
+ if attr is None:
+ # fill in missing attr spots with empty MappedColumn
+ attr = MappedColumn()
+ self.attrs += (attr,)
+
if isinstance(attr, MappedColumn):
attr.declarative_scan_for_composite(
registry, cls, key, param.name, param.annotation
@@ -800,10 +805,11 @@ class Synonym(DescriptorProperty[_T]):
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ attribute_options: Optional[_AttributeOptions] = None,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
self.name = name
self.map_column = map_column
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
index 4fa61b7ce..33de2aee9 100644
--- a/lib/sqlalchemy/orm/instrumentation.py
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -113,6 +113,7 @@ class ClassManager(
"previously known as deferred_scalar_loader"
init_method: Optional[Callable[..., None]]
+ original_init: Optional[Callable[..., None]] = None
factory: Optional[_ManagerFactory]
@@ -229,7 +230,7 @@ class ClassManager(
if finalize and not self._finalized:
self._finalize()
- def _finalize(self):
+ def _finalize(self) -> None:
if self._finalized:
return
self._finalized = True
@@ -238,14 +239,14 @@ class ClassManager(
_instrumentation_factory.dispatch.class_instrument(self.class_)
- def __hash__(self):
+ def __hash__(self) -> int: # type: ignore[override]
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return other is self
@property
- def is_mapped(self):
+ def is_mapped(self) -> bool:
return "mapper" in self.__dict__
@HasMemoized.memoized_attribute
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index b5569ce06..e0034061d 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -19,6 +19,7 @@ are exposed when inspecting mappings.
from __future__ import annotations
import collections
+import dataclasses
import typing
from typing import Any
from typing import Callable
@@ -27,6 +28,8 @@ from typing import ClassVar
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
@@ -51,11 +54,13 @@ from .base import ONETOMANY as ONETOMANY # noqa: F401
from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
+from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql import operators
from ..sql import roles
from ..sql import visitors
+from ..sql.base import _NoArg
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
@@ -141,6 +146,7 @@ class _IntrospectsAnnotations:
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
@@ -150,6 +156,70 @@ class _IntrospectsAnnotations:
"""
+ def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn:
+ raise sa_exc.ArgumentError(
+ f"Python typing annotation is required for attribute "
+ f'"{cls.__name__}.{key}" when primary argument(s) for '
+ f'"{self.__class__.__name__}" construct are None or not present'
+ )
+
+
+class _AttributeOptions(NamedTuple):
+ """define Python-local attribute behavior options common to all
+ :class:`.MapperProperty` objects.
+
+ Currently this includes dataclass-generation arguments.
+
+ .. versionadded:: 2.0
+
+ """
+
+ dataclasses_init: Union[_NoArg, bool]
+ dataclasses_repr: Union[_NoArg, bool]
+ dataclasses_default: Union[_NoArg, Any]
+ dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+
+ def _as_dataclass_field(self) -> Any:
+ """Return a ``dataclasses.Field`` object given these arguments."""
+
+ kw: Dict[str, Any] = {}
+ if self.dataclasses_default_factory is not _NoArg.NO_ARG:
+ kw["default_factory"] = self.dataclasses_default_factory
+ if self.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = self.dataclasses_default
+ if self.dataclasses_init is not _NoArg.NO_ARG:
+ kw["init"] = self.dataclasses_init
+ if self.dataclasses_repr is not _NoArg.NO_ARG:
+ kw["repr"] = self.dataclasses_repr
+
+ return dataclasses.field(**kw)
+
+ @classmethod
+ def _get_arguments_for_make_dataclass(
+ cls, key: str, annotation: Type[Any], elem: _T
+ ) -> Union[
+ Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+ ]:
+ """given attribute key, annotation, and value from a class, return
+ the argument tuple we would pass to dataclasses.make_dataclass()
+ for this attribute.
+
+ """
+ if isinstance(elem, (MapperProperty, _MapsColumns)):
+ dc_field = elem._attribute_options._as_dataclass_field()
+
+ return (key, annotation, dc_field)
+ elif elem is not _NoArg.NO_ARG:
+ # why is typing not erroring on this?
+ return (key, annotation, elem)
+ else:
+ return (key, annotation)
+
+
+_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
+ _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+)
+
class _MapsColumns(_MappedAttribute[_T]):
"""interface for declarative-capable construct that delivers one or more
@@ -158,6 +228,9 @@ class _MapsColumns(_MappedAttribute[_T]):
__slots__ = ()
+ _attribute_options: _AttributeOptions
+ _has_dataclass_arguments: bool
+
@property
def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
"""return a MapperProperty to be assigned to the declarative mapping"""
@@ -199,6 +272,8 @@ class MapperProperty(
__slots__ = (
"_configure_started",
"_configure_finished",
+ "_attribute_options",
+ "_has_dataclass_arguments",
"parent",
"key",
"info",
@@ -241,6 +316,15 @@ class MapperProperty(
doc: Optional[str]
"""optional documentation string"""
+ _attribute_options: _AttributeOptions
+ """behavioral options for ORM-enabled Python attributes
+
+ .. versionadded:: 2.0
+
+ """
+
+ _has_dataclass_arguments: bool
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
@@ -349,9 +433,20 @@ class MapperProperty(
"""
- def __init__(self) -> None:
+ def __init__(
+ self, attribute_options: Optional[_AttributeOptions] = None
+ ) -> None:
self._configure_started = False
self._configure_finished = False
+ if (
+ attribute_options
+ and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS
+ ):
+ self._has_dataclass_arguments = True
+ self._attribute_options = attribute_options
+ else:
+ self._has_dataclass_arguments = False
+ self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
def init(self) -> None:
"""Called after all mappers are created to assemble
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index ad3e9f248..7655f3ae2 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -30,13 +30,14 @@ from . import strategy_options
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
+from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
from .interfaces import StrategizedProperty
from .relationships import Relationship
-from .util import _extract_mapped_subtype
from .util import _orm_full_deannotate
from .. import exc as sa_exc
from .. import ForeignKey
@@ -45,6 +46,7 @@ from .. import util
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.base import _NoArg
from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
@@ -131,6 +133,7 @@ class ColumnProperty(
self,
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
@@ -141,7 +144,9 @@ class ColumnProperty(
doc: Optional[str] = None,
_instrument: bool = True,
):
- super(ColumnProperty, self).__init__()
+ super(ColumnProperty, self).__init__(
+ attribute_options=attribute_options
+ )
columns = (column,) + additional_columns
self._orig_columns = [
coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
@@ -193,6 +198,7 @@ class ColumnProperty(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.columns[0]
@@ -487,13 +493,38 @@ class MappedColumn(
"foreign_keys",
"_has_nullable",
"deferred",
+ "_attribute_options",
+ "_has_dataclass_arguments",
)
deferred: bool
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
+ _attribute_options: _AttributeOptions
def __init__(self, *arg: Any, **kw: Any):
+ self._attribute_options = attr_opts = kw.pop(
+ "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
+ )
+
+ self._has_dataclass_arguments = False
+
+ if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
+ if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
+ self._has_dataclass_arguments = True
+ kw["default"] = attr_opts.dataclasses_default_factory
+ elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = attr_opts.dataclasses_default
+
+ if (
+ attr_opts.dataclasses_init is not _NoArg.NO_ARG
+ or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
+ ):
+ self._has_dataclass_arguments = True
+
+ if "default" in kw and kw["default"] is _NoArg.NO_ARG:
+ kw.pop("default")
+
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
@@ -509,13 +540,19 @@ class MappedColumn(
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
new._has_nullable = self._has_nullable
+ new._attribute_options = self._attribute_options
+ new._has_dataclass_arguments = self._has_dataclass_arguments
util.set_creation_order(new)
return new
@property
def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
if self.deferred:
- return ColumnProperty(self.column, deferred=True)
+ return ColumnProperty(
+ self.column,
+ deferred=True,
+ attribute_options=self._attribute_options,
+ )
else:
return None
@@ -543,6 +580,7 @@ class MappedColumn(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.column
@@ -553,18 +591,15 @@ class MappedColumn(
sqltype = column.type
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- sqltype._isnull and not self.column.foreign_keys,
- is_dataclass_field,
- )
- if argument is None:
- return
+ if extracted_mapped_annotation is None:
+ if sqltype._isnull and not self.column.foreign_keys:
+ self._raise_for_required(key, cls)
+ else:
+ return
- self._init_column_for_annotation(cls, registry, argument)
+ self._init_column_for_annotation(
+ cls, registry, extracted_mapped_annotation
+ )
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 1186f0f54..deaf52147 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -49,6 +49,7 @@ from .base import class_mapper
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .base import state_str
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
@@ -56,7 +57,6 @@ from .interfaces import ONETOMANY
from .interfaces import PropComparator
from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
-from .util import _extract_mapped_subtype
from .util import _orm_annotate
from .util import _orm_deannotate
from .util import CascadeOptions
@@ -355,6 +355,7 @@ class Relationship(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ attribute_options: Optional[_AttributeOptions] = None,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
@@ -380,7 +381,7 @@ class Relationship(
_local_remote_pairs: Optional[_ColumnPairs] = None,
_legacy_inactive_history_style: bool = False,
):
- super(Relationship, self).__init__()
+ super(Relationship, self).__init__(attribute_options=attribute_options)
self.uselist = uselist
self.argument = argument
@@ -1701,18 +1702,19 @@ class Relationship(
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- Relationship,
- self.argument is None,
- is_dataclass_field,
- )
- if argument is None:
- return
+ argument = extracted_mapped_annotation
+
+ if extracted_mapped_annotation is None:
+
+ if self.argument is None:
+ self._raise_for_required(key, cls)
+ else:
+ return
+
+ argument = extracted_mapped_annotation
if hasattr(argument, "__origin__"):
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index c50cc5bac..520c95672 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -1927,7 +1927,7 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
def _is_mapped_annotation(
- raw_annotation: Union[type, str], cls: Type[Any]
+ raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
@@ -1969,9 +1969,14 @@ def _extract_mapped_subtype(
attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
- superclasses: Optional[Tuple[Type[Any], ...]] = None,
+ expect_mapped: bool = True,
) -> Optional[Union[type, str]]:
+ """given an annotation, figure out if it's ``Mapped[something]`` and if
+ so, return the ``something`` part.
+ Includes error raise scenarios and other options.
+
+ """
if raw_annotation is None:
if required:
@@ -1989,25 +1994,29 @@ def _extract_mapped_subtype(
if is_dataclass_field:
return annotated
else:
- # TODO: there don't seem to be tests for the failure
- # conditions here
- if not hasattr(annotated, "__origin__") or (
- not issubclass(
- annotated.__origin__, # type: ignore
- superclasses if superclasses else attr_cls,
- )
- and not issubclass(attr_cls, annotated.__origin__) # type: ignore
+ if not hasattr(annotated, "__origin__") or not is_origin_of(
+ annotated, "Mapped", module="sqlalchemy.orm"
):
- our_annotated_str = (
- annotated.__name__
+ anno_name = (
+ getattr(annotated, "__name__", None)
if not isinstance(annotated, str)
- else repr(annotated)
- )
- raise sa_exc.ArgumentError(
- f'Type annotation for "{cls.__name__}.{key}" should use the '
- f'syntax "Mapped[{our_annotated_str}]" or '
- f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ else None
)
+ if anno_name is None:
+ our_annotated_str = repr(annotated)
+ else:
+ our_annotated_str = anno_name
+
+ if expect_mapped:
+ raise sa_exc.ArgumentError(
+ f'Type annotation for "{cls.__name__}.{key}" '
+ "should use the "
+ f'syntax "Mapped[{our_annotated_str}]" or '
+ f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ )
+
+ else:
+ return annotated
if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index 53f76f3ce..d4e4d2dca 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -25,6 +25,7 @@ from .. import event
from .. import util
from ..orm import declarative_base
from ..orm import DeclarativeBase
+from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
@@ -90,7 +91,14 @@ class TestBase:
@config.fixture()
def registry(self, metadata):
- reg = registry(metadata=metadata)
+ reg = registry(
+ metadata=metadata,
+ type_annotation_map={
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ },
+ )
yield reg
reg.dispose()
@@ -109,6 +117,21 @@ class TestBase:
yield Base
Base.registry.dispose()
+ @config.fixture
+ def dc_decl_base(self, metadata):
+ _md = metadata
+
+ class Base(MappedAsDataclass, DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index adbbf143f..4ce1e7ff3 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -230,7 +230,11 @@ def inspect_formatargspec(
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
- with a class."""
+ with a class as an already processed dataclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
@@ -240,7 +244,12 @@ def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
- a class, excluding those that originate from a superclass."""
+ an already processed dataclass, excluding those that originate from a
+ superclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
super_fields: Set[dataclasses.Field[Any]] = set()
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index 44e26f609..454de100b 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -23,6 +23,14 @@ from typing_extensions import NotRequired as NotRequired # noqa: F401
from . import compat
+
+# more zimports issues
+if True:
+ from typing_extensions import ( # noqa: F401
+ dataclass_transform as dataclass_transform,
+ )
+
+
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)