diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2023-02-16 02:37:52 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-02-16 02:37:52 +0000 |
| commit | 7ec7563892acf4c67abe638afc4ca970eefbcd51 (patch) | |
| tree | 944292f46dd8caa4b47d9cc4c74e55dac0114d87 /lib | |
| parent | 3fd081d070716fd5fc578555f945d503f9a91f91 (diff) | |
| parent | 18fd19e60d55b35408d94b892e0a2051bcb7ec88 (diff) | |
| download | sqlalchemy-7ec7563892acf4c67abe638afc4ca970eefbcd51.tar.gz | |
Merge "add dataclasses callable and apply annotations more strictly" into main
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_api.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 81 |
2 files changed, 82 insertions, 13 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index d02012b86..f332d2964 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -593,6 +593,9 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, ) -> None: apply_dc_transforms: _DataclassArguments = { "init": init, @@ -602,6 +605,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): "unsafe_hash": unsafe_hash, "match_args": match_args, "kw_only": kw_only, + "dataclass_callable": dataclass_callable, } current_transforms: _DataclassArguments @@ -623,8 +627,11 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): super().__init_subclass__() if not _is_mapped_class(cls): + new_anno = ( + _ClassScanMapperConfig._update_annotations_for_non_mapped_class + )(cls) _ClassScanMapperConfig._apply_dataclasses_to_any_class( - current_transforms, cls + current_transforms, cls, new_anno ) @@ -1569,6 +1576,7 @@ class registry: unsafe_hash: Union[_NoArg, bool] = ..., match_args: Union[_NoArg, bool] = ..., kw_only: Union[_NoArg, bool] = ..., + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., ) -> Callable[[Type[_O]], Type[_O]]: ... @@ -1583,6 +1591,9 @@ class registry: unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: """Class decorator that will apply the Declarative mapping process to a given class, and additionally convert the class to be a @@ -1608,6 +1619,7 @@ class registry: "unsafe_hash": unsafe_hash, "match_args": match_args, "kw_only": kw_only, + "dataclass_callable": dataclass_callable, } _as_declarative(self, cls, cls.__dict__) return cls diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index aeed9b439..f0be55b89 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -126,6 +126,7 @@ class _DataclassArguments(TypedDict): unsafe_hash: Union[_NoArg, bool] match_args: Union[_NoArg, bool] kw_only: Union[_NoArg, bool] + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] def _declared_mapping_info( @@ -1099,26 +1100,81 @@ class _ClassScanMapperConfig(_MapperConfig): for k, v in defaults.items(): setattr(self.cls, k, v) - self.cls.__annotations__ = annotations - self._apply_dataclasses_to_any_class( - dataclass_setup_arguments, self.cls + dataclass_setup_arguments, self.cls, annotations ) @classmethod + def _update_annotations_for_non_mapped_class( + cls, klass: Type[_O] + ) -> Mapping[str, _AnnotationScanType]: + cls_annotations = util.get_annotations(klass) + + new_anno = {} + for name, annotation in cls_annotations.items(): + if _is_mapped_annotation(annotation, klass, klass): + + extracted = _extract_mapped_subtype( + annotation, + klass, + klass.__module__, + name, + type(None), + required=False, + is_dataclass_field=False, + expect_mapped=False, + ) + if extracted: + inner, _ = extracted + new_anno[name] = inner + else: + new_anno[name] = annotation + return new_anno + + @classmethod def _apply_dataclasses_to_any_class( - cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O] + cls, + dataclass_setup_arguments: _DataclassArguments, + klass: Type[_O], + use_annotations: Mapping[str, _AnnotationScanType], ) -> None: cls._assert_dc_arguments(dataclass_setup_arguments) - dataclasses.dataclass( - klass, - **{ - k: v - for k, v in dataclass_setup_arguments.items() - if v is not _NoArg.NO_ARG - }, - ) + dataclass_callable = dataclass_setup_arguments["dataclass_callable"] + if dataclass_callable is _NoArg.NO_ARG: + dataclass_callable = dataclasses.dataclass + + restored: Optional[Any] + + if use_annotations: + # apply constructed annotations that should look "normal" to a + # dataclasses callable, based on the fields present. This + # means remove the Mapped[] container and ensure all Field + # entries have an annotation + restored = getattr(klass, "__annotations__", None) + klass.__annotations__ = cast("Dict[str, Any]", use_annotations) + else: + restored = None + + try: + dataclass_callable( + klass, + **{ + k: v + for k, v in dataclass_setup_arguments.items() + if v is not _NoArg.NO_ARG and k != "dataclass_callable" + }, + ) + finally: + # restore original annotations outside of the dataclasses + # process; for mixins and __abstract__ superclasses, SQLAlchemy + # Declarative will need to see the Mapped[] container inside the + # annotations in order to map subclasses + if use_annotations: + if restored is None: + del klass.__annotations__ + else: + klass.__annotations__ = restored @classmethod def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: @@ -1130,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig): "unsafe_hash", "kw_only", "match_args", + "dataclass_callable", } disallowed_args = set(arguments).difference(allowed) if disallowed_args: |
