diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2020-12-19 19:56:14 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2020-12-19 19:56:14 +0000 |
| commit | a8f51f3c11f3cb2e344732cf3abb371f03ed30d8 (patch) | |
| tree | 316243093b637ebb286d1b8d6a069e51b70c26c7 /lib/sqlalchemy | |
| parent | b63b4275f0059c61c158fafe989e188eb6f8332e (diff) | |
| parent | 9625adba3553803acd5488660d65c8e675a61fa6 (diff) | |
| download | sqlalchemy-a8f51f3c11f3cb2e344732cf3abb371f03ed30d8.tar.gz | |
Merge "Allow Declarative to extract class attr from field"
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/decl_base.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 16 |
4 files changed, 42 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 8da326b0e..353f44e43 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -334,6 +334,9 @@ class _ClassScanMapperConfig(_MapperConfig): tablename = None for base in cls.__mro__: + + sa_dataclass_metadata_key = None + class_mapped = ( base is not cls and _declared_mapping_info(base) is not None @@ -342,10 +345,25 @@ class _ClassScanMapperConfig(_MapperConfig): ) ) + if sa_dataclass_metadata_key is None: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + base, "__sa_dataclass_metadata_key__", None + ) + + def attributes_for_class(cls): + for name, obj in vars(cls).items(): + yield name, obj + if sa_dataclass_metadata_key: + for field in util.dataclass_fields(cls): + if sa_dataclass_metadata_key in field.metadata: + yield field.name, field.metadata[ + sa_dataclass_metadata_key + ] + if not class_mapped and base is not cls: - self._produce_column_copies(base) + self._produce_column_copies(attributes_for_class, base) - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -452,6 +470,8 @@ class _ClassScanMapperConfig(_MapperConfig): # however, check for some more common mistakes else: self._warn_for_decl_attributes(base, name, obj) + elif name not in dict_ or dict_[name] is not obj: + dict_[name] = obj if inherited_table_args and not tablename: table_args = None @@ -469,12 +489,12 @@ class _ClassScanMapperConfig(_MapperConfig): % (key, cls) ) - def _produce_column_copies(self, base): + def _produce_column_copies(self, attributes_for_class, base): cls = self.cls dict_ = self.dict_ column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if isinstance(obj, Column): if getattr(cls, name) is not obj: # if column has been overridden diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e50183894..e8f98d150 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -56,12 +56,6 @@ from ..sql import util as sql_util from ..sql import visitors from ..util import HasMemoized -try: - import dataclasses -except ImportError: - # The dataclasses module was added in Python 3.7 - dataclasses = None - _mapper_registry = weakref.WeakKeyDictionary() _already_compiling = False @@ -2645,10 +2639,7 @@ class Mapper( @HasMemoized.memoized_attribute def _dataclass_fields(self): - if dataclasses is None or not dataclasses.is_dataclass(self.class_): - return frozenset() - - return {field.name for field in dataclasses.fields(self.class_)} + return [f.name for f in util.dataclass_fields(self.class_)] def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index f4363d03c..2e3f68722 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -57,6 +57,7 @@ from .compat import byte_buffer # noqa from .compat import callable # noqa from .compat import cmp # noqa from .compat import cpython # noqa +from .compat import dataclass_fields # noqa from .compat import decode_backslashreplace # noqa from .compat import dottedgetter # noqa from .compat import has_refcount_gc # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index e8c488047..77c913640 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -421,6 +421,22 @@ else: import collections as collections_abc # noqa +if py37: + import dataclasses + + def dataclass_fields(cls): + if dataclasses.is_dataclass(cls): + return dataclasses.fields(cls) + else: + return [] + + +else: + + def dataclass_fields(cls): + return [] + + def raise_from_cause(exception, exc_info=None): r"""legacy. use raise\_()""" |
