summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-02-16 02:37:52 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-02-16 02:37:52 +0000
commit7ec7563892acf4c67abe638afc4ca970eefbcd51 (patch)
tree944292f46dd8caa4b47d9cc4c74e55dac0114d87 /lib
parent3fd081d070716fd5fc578555f945d503f9a91f91 (diff)
parent18fd19e60d55b35408d94b892e0a2051bcb7ec88 (diff)
downloadsqlalchemy-7ec7563892acf4c67abe638afc4ca970eefbcd51.tar.gz
Merge "add dataclasses callable and apply annotations more strictly" into main
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/decl_api.py14
-rw-r--r--lib/sqlalchemy/orm/decl_base.py81
2 files changed, 82 insertions, 13 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index d02012b86..f332d2964 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -593,6 +593,9 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ dataclass_callable: Union[
+ _NoArg, Callable[..., Type[Any]]
+ ] = _NoArg.NO_ARG,
) -> None:
apply_dc_transforms: _DataclassArguments = {
"init": init,
@@ -602,6 +605,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
"unsafe_hash": unsafe_hash,
"match_args": match_args,
"kw_only": kw_only,
+ "dataclass_callable": dataclass_callable,
}
current_transforms: _DataclassArguments
@@ -623,8 +627,11 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
super().__init_subclass__()
if not _is_mapped_class(cls):
+ new_anno = (
+ _ClassScanMapperConfig._update_annotations_for_non_mapped_class
+ )(cls)
_ClassScanMapperConfig._apply_dataclasses_to_any_class(
- current_transforms, cls
+ current_transforms, cls, new_anno
)
@@ -1569,6 +1576,7 @@ class registry:
unsafe_hash: Union[_NoArg, bool] = ...,
match_args: Union[_NoArg, bool] = ...,
kw_only: Union[_NoArg, bool] = ...,
+ dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ...,
) -> Callable[[Type[_O]], Type[_O]]:
...
@@ -1583,6 +1591,9 @@ class registry:
unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ dataclass_callable: Union[
+ _NoArg, Callable[..., Type[Any]]
+ ] = _NoArg.NO_ARG,
) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
"""Class decorator that will apply the Declarative mapping process
to a given class, and additionally convert the class to be a
@@ -1608,6 +1619,7 @@ class registry:
"unsafe_hash": unsafe_hash,
"match_args": match_args,
"kw_only": kw_only,
+ "dataclass_callable": dataclass_callable,
}
_as_declarative(self, cls, cls.__dict__)
return cls
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index aeed9b439..f0be55b89 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -126,6 +126,7 @@ class _DataclassArguments(TypedDict):
unsafe_hash: Union[_NoArg, bool]
match_args: Union[_NoArg, bool]
kw_only: Union[_NoArg, bool]
+ dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]]
def _declared_mapping_info(
@@ -1099,26 +1100,81 @@ class _ClassScanMapperConfig(_MapperConfig):
for k, v in defaults.items():
setattr(self.cls, k, v)
- self.cls.__annotations__ = annotations
-
self._apply_dataclasses_to_any_class(
- dataclass_setup_arguments, self.cls
+ dataclass_setup_arguments, self.cls, annotations
)
@classmethod
+ def _update_annotations_for_non_mapped_class(
+ cls, klass: Type[_O]
+ ) -> Mapping[str, _AnnotationScanType]:
+ cls_annotations = util.get_annotations(klass)
+
+ new_anno = {}
+ for name, annotation in cls_annotations.items():
+ if _is_mapped_annotation(annotation, klass, klass):
+
+ extracted = _extract_mapped_subtype(
+ annotation,
+ klass,
+ klass.__module__,
+ name,
+ type(None),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=False,
+ )
+ if extracted:
+ inner, _ = extracted
+ new_anno[name] = inner
+ else:
+ new_anno[name] = annotation
+ return new_anno
+
+ @classmethod
def _apply_dataclasses_to_any_class(
- cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O]
+ cls,
+ dataclass_setup_arguments: _DataclassArguments,
+ klass: Type[_O],
+ use_annotations: Mapping[str, _AnnotationScanType],
) -> None:
cls._assert_dc_arguments(dataclass_setup_arguments)
- dataclasses.dataclass(
- klass,
- **{
- k: v
- for k, v in dataclass_setup_arguments.items()
- if v is not _NoArg.NO_ARG
- },
- )
+ dataclass_callable = dataclass_setup_arguments["dataclass_callable"]
+ if dataclass_callable is _NoArg.NO_ARG:
+ dataclass_callable = dataclasses.dataclass
+
+ restored: Optional[Any]
+
+ if use_annotations:
+ # apply constructed annotations that should look "normal" to a
+ # dataclasses callable, based on the fields present. This
+ # means remove the Mapped[] container and ensure all Field
+ # entries have an annotation
+ restored = getattr(klass, "__annotations__", None)
+ klass.__annotations__ = cast("Dict[str, Any]", use_annotations)
+ else:
+ restored = None
+
+ try:
+ dataclass_callable(
+ klass,
+ **{
+ k: v
+ for k, v in dataclass_setup_arguments.items()
+ if v is not _NoArg.NO_ARG and k != "dataclass_callable"
+ },
+ )
+ finally:
+ # restore original annotations outside of the dataclasses
+ # process; for mixins and __abstract__ superclasses, SQLAlchemy
+ # Declarative will need to see the Mapped[] container inside the
+ # annotations in order to map subclasses
+ if use_annotations:
+ if restored is None:
+ del klass.__annotations__
+ else:
+ klass.__annotations__ = restored
@classmethod
def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
@@ -1130,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig):
"unsafe_hash",
"kw_only",
"match_args",
+ "dataclass_callable",
}
disallowed_args = set(arguments).difference(allowed)
if disallowed_args: