summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-12-03 14:35:30 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2020-12-19 12:43:16 -0500
commit9625adba3553803acd5488660d65c8e675a61fa6 (patch)
tree23106014c9c6ed76ab2a81502cafa2ce99cbe0b5 /lib/sqlalchemy
parent27766512b2d037a8f0048dccc6e2f02c281fbc9a (diff)
downloadsqlalchemy-9625adba3553803acd5488660d65c8e675a61fa6.tar.gz
Allow Declarative to extract class attr from field
Added an alternate resolution scheme to Declarative that will extract the SQLAlchemy column or mapped property from the "metadata" dictionary of a dataclasses.Field object. This allows full declarative mappings to be combined with dataclass fields. Fixes: #5745 Change-Id: I1165bc025246a4cb9fc099b1b7c46a6b0f799b23
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/decl_base.py28
-rw-r--r--lib/sqlalchemy/orm/mapper.py11
-rw-r--r--lib/sqlalchemy/util/__init__.py1
-rw-r--r--lib/sqlalchemy/util/compat.py16
4 files changed, 42 insertions, 14 deletions
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index 8da326b0e..353f44e43 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -334,6 +334,9 @@ class _ClassScanMapperConfig(_MapperConfig):
tablename = None
for base in cls.__mro__:
+
+ sa_dataclass_metadata_key = None
+
class_mapped = (
base is not cls
and _declared_mapping_info(base) is not None
@@ -342,10 +345,25 @@ class _ClassScanMapperConfig(_MapperConfig):
)
)
+ if sa_dataclass_metadata_key is None:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ base, "__sa_dataclass_metadata_key__", None
+ )
+
+ def attributes_for_class(cls):
+ for name, obj in vars(cls).items():
+ yield name, obj
+ if sa_dataclass_metadata_key:
+ for field in util.dataclass_fields(cls):
+ if sa_dataclass_metadata_key in field.metadata:
+ yield field.name, field.metadata[
+ sa_dataclass_metadata_key
+ ]
+
if not class_mapped and base is not cls:
- self._produce_column_copies(base)
+ self._produce_column_copies(attributes_for_class, base)
- for name, obj in vars(base).items():
+ for name, obj in attributes_for_class(base):
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
@@ -452,6 +470,8 @@ class _ClassScanMapperConfig(_MapperConfig):
# however, check for some more common mistakes
else:
self._warn_for_decl_attributes(base, name, obj)
+ elif name not in dict_ or dict_[name] is not obj:
+ dict_[name] = obj
if inherited_table_args and not tablename:
table_args = None
@@ -469,12 +489,12 @@ class _ClassScanMapperConfig(_MapperConfig):
% (key, cls)
)
- def _produce_column_copies(self, base):
+ def _produce_column_copies(self, attributes_for_class, base):
cls = self.cls
dict_ = self.dict_
column_copies = self.column_copies
# copy mixin columns to the mapped class
- for name, obj in vars(base).items():
+ for name, obj in attributes_for_class(base):
if isinstance(obj, Column):
if getattr(cls, name) is not obj:
# if column has been overridden
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index e50183894..e8f98d150 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -56,12 +56,6 @@ from ..sql import util as sql_util
from ..sql import visitors
from ..util import HasMemoized
-try:
- import dataclasses
-except ImportError:
- # The dataclasses module was added in Python 3.7
- dataclasses = None
-
_mapper_registry = weakref.WeakKeyDictionary()
_already_compiling = False
@@ -2645,10 +2639,7 @@ class Mapper(
@HasMemoized.memoized_attribute
def _dataclass_fields(self):
- if dataclasses is None or not dataclasses.is_dataclass(self.class_):
- return frozenset()
-
- return {field.name for field in dataclasses.fields(self.class_)}
+ return [f.name for f in util.dataclass_fields(self.class_)]
def _should_exclude(self, name, assigned_name, local, column):
"""determine whether a particular property should be implicitly
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index f4363d03c..2e3f68722 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -57,6 +57,7 @@ from .compat import byte_buffer # noqa
from .compat import callable # noqa
from .compat import cmp # noqa
from .compat import cpython # noqa
+from .compat import dataclass_fields # noqa
from .compat import decode_backslashreplace # noqa
from .compat import dottedgetter # noqa
from .compat import has_refcount_gc # noqa
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index e8c488047..77c913640 100644
--- a/lib/sqlalchemy/util/compat.py
+++ b/lib/sqlalchemy/util/compat.py
@@ -421,6 +421,22 @@ else:
import collections as collections_abc # noqa
+if py37:
+ import dataclasses
+
+ def dataclass_fields(cls):
+ if dataclasses.is_dataclass(cls):
+ return dataclasses.fields(cls)
+ else:
+ return []
+
+
+else:
+
+ def dataclass_fields(cls):
+ return []
+
+
def raise_from_cause(exception, exc_info=None):
r"""legacy. use raise\_()"""