diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-03 14:35:30 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-12-19 12:43:16 -0500 |
| commit | 9625adba3553803acd5488660d65c8e675a61fa6 (patch) | |
| tree | 23106014c9c6ed76ab2a81502cafa2ce99cbe0b5 /lib/sqlalchemy | |
| parent | 27766512b2d037a8f0048dccc6e2f02c281fbc9a (diff) | |
| download | sqlalchemy-9625adba3553803acd5488660d65c8e675a61fa6.tar.gz | |
Allow Declarative to extract class attr from field
Added an alternate resolution scheme to Declarative that will extract the
SQLAlchemy column or mapped property from the "metadata" dictionary of a
dataclasses.Field object. This allows full declarative mappings to be
combined with dataclass fields.
Fixes: #5745
Change-Id: I1165bc025246a4cb9fc099b1b7c46a6b0f799b23
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 16 |
4 files changed, 42 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 8da326b0e..353f44e43 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -334,6 +334,9 @@ class _ClassScanMapperConfig(_MapperConfig): tablename = None for base in cls.__mro__: + + sa_dataclass_metadata_key = None + class_mapped = ( base is not cls and _declared_mapping_info(base) is not None @@ -342,10 +345,25 @@ class _ClassScanMapperConfig(_MapperConfig): ) ) + if sa_dataclass_metadata_key is None: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + base, "__sa_dataclass_metadata_key__", None + ) + + def attributes_for_class(cls): + for name, obj in vars(cls).items(): + yield name, obj + if sa_dataclass_metadata_key: + for field in util.dataclass_fields(cls): + if sa_dataclass_metadata_key in field.metadata: + yield field.name, field.metadata[ + sa_dataclass_metadata_key + ] + if not class_mapped and base is not cls: - self._produce_column_copies(base) + self._produce_column_copies(attributes_for_class, base) - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -452,6 +470,8 @@ class _ClassScanMapperConfig(_MapperConfig): # however, check for some more common mistakes else: self._warn_for_decl_attributes(base, name, obj) + elif name not in dict_ or dict_[name] is not obj: + dict_[name] = obj if inherited_table_args and not tablename: table_args = None @@ -469,12 +489,12 @@ class _ClassScanMapperConfig(_MapperConfig): % (key, cls) ) - def _produce_column_copies(self, base): + def _produce_column_copies(self, attributes_for_class, base): cls = self.cls dict_ = self.dict_ column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if isinstance(obj, Column): if getattr(cls, name) is not obj: # if column has been overridden diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e50183894..e8f98d150 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -56,12 +56,6 @@ from ..sql import util as sql_util from ..sql import visitors from ..util import HasMemoized -try: - import dataclasses -except ImportError: - # The dataclasses module was added in Python 3.7 - dataclasses = None - _mapper_registry = weakref.WeakKeyDictionary() _already_compiling = False @@ -2645,10 +2639,7 @@ class Mapper( @HasMemoized.memoized_attribute def _dataclass_fields(self): - if dataclasses is None or not dataclasses.is_dataclass(self.class_): - return frozenset() - - return {field.name for field in dataclasses.fields(self.class_)} + return [f.name for f in util.dataclass_fields(self.class_)] def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index f4363d03c..2e3f68722 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -57,6 +57,7 @@ from .compat import byte_buffer # noqa from .compat import callable # noqa from .compat import cmp # noqa from .compat import cpython # noqa +from .compat import dataclass_fields # noqa from .compat import decode_backslashreplace # noqa from .compat import dottedgetter # noqa from .compat import has_refcount_gc # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index e8c488047..77c913640 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -421,6 +421,22 @@ else: import collections as collections_abc # noqa +if py37: + import dataclasses + + def dataclass_fields(cls): + if dataclasses.is_dataclass(cls): + return dataclasses.fields(cls) + else: + return [] + + +else: + + def dataclass_fields(cls): + return [] + + def raise_from_cause(exception, exc_info=None): r"""legacy. use raise\_()""" |
