summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/decl_api.py10
-rw-r--r--lib/sqlalchemy/orm/decl_base.py48
-rw-r--r--test/orm/declarative/test_mixin.py90
-rw-r--r--test/orm/declarative/test_typed_mapping.py5
4 files changed, 128 insertions, 25 deletions
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
index e6e69a9e0..05d6dacfb 100644
--- a/lib/sqlalchemy/orm/decl_api.py
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -324,6 +324,16 @@ class declared_attr(interfaces._MappedAttribute[_T]):
fn: _DeclaredAttrDecorated[_T],
cascading: bool = False,
):
+ # suppport
+ # @declared_attr
+ # @classmethod
+ # def foo(cls) -> Mapped[thing]:
+ # ...
+ # which seems to help typing tools interpret the fn as a classmethod
+ # for situations where needed
+ if isinstance(fn, classmethod):
+ fn = fn.__func__ # type: ignore
+
self.fget = fn
self._cascading = cascading
self.__doc__ = fn.__doc__
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
index 62251fa2b..108027dd5 100644
--- a/lib/sqlalchemy/orm/decl_base.py
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -737,6 +737,7 @@ class _ClassScanMapperConfig(_MapperConfig):
locally_collected_columns = self._produce_column_copies(
local_attributes_for_class,
attribute_is_overridden,
+ fixed_table,
)
else:
locally_collected_columns = {}
@@ -828,9 +829,7 @@ class _ClassScanMapperConfig(_MapperConfig):
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self._collect_annotation(
- name, annotation, is_dataclass_field, True, obj
- )
+ self._collect_annotation(name, annotation, True, obj)
# already copied columns to the mapped class.
continue
elif isinstance(obj, MapperProperty):
@@ -913,23 +912,18 @@ class _ClassScanMapperConfig(_MapperConfig):
self._collect_annotation(
name,
obj._collect_return_annotation(),
- False,
True,
obj,
)
elif _is_mapped_annotation(annotation, cls):
- generated_obj = self._collect_annotation(
- name, annotation, is_dataclass_field, True, obj
- )
- if obj is None:
- if not fixed_table:
- collected_attributes[name] = (
- generated_obj
- if generated_obj is not None
- else MappedColumn()
- )
- else:
- collected_attributes[name] = obj
+ # Mapped annotation without any object.
+ # product_column_copies should have handled this.
+ # if future support for other MapperProperty,
+ # then test if this name is already handled and
+ # otherwise proceed to generate.
+ if not fixed_table:
+ assert name in collected_attributes
+ continue
else:
# here, the attribute is some other kind of
# property that we assume is not part of the
@@ -953,12 +947,10 @@ class _ClassScanMapperConfig(_MapperConfig):
obj = obj.fget()
collected_attributes[name] = obj
- self._collect_annotation(
- name, annotation, True, False, obj
- )
+ self._collect_annotation(name, annotation, False, obj)
else:
generated_obj = self._collect_annotation(
- name, annotation, False, None, obj
+ name, annotation, None, obj
)
if (
obj is None
@@ -1060,7 +1052,6 @@ class _ClassScanMapperConfig(_MapperConfig):
self,
name: str,
raw_annotation: _AnnotationScanType,
- is_dataclass: bool,
expect_mapped: Optional[bool],
attr_value: Any,
) -> Any:
@@ -1128,6 +1119,7 @@ class _ClassScanMapperConfig(_MapperConfig):
[], Iterable[Tuple[str, Any, Any, bool]]
],
attribute_is_overridden: Callable[[str, Any], bool],
+ fixed_table: bool,
) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]:
cls = self.cls
dict_ = self.clsdict_view
@@ -1136,7 +1128,19 @@ class _ClassScanMapperConfig(_MapperConfig):
# copy mixin columns to the mapped class
for name, obj, annotation, is_dataclass in attributes_for_class():
- if isinstance(obj, (Column, MappedColumn)):
+ if (
+ not fixed_table
+ and obj is None
+ and _is_mapped_annotation(annotation, cls)
+ ):
+ obj = self._collect_annotation(name, annotation, True, obj)
+ if obj is None:
+ obj = MappedColumn()
+
+ locally_collected_attributes[name] = obj
+ setattr(cls, name, obj)
+
+ elif isinstance(obj, (Column, MappedColumn)):
if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py
index 72e14ceeb..a6851de5b 100644
--- a/test/orm/declarative/test_mixin.py
+++ b/test/orm/declarative/test_mixin.py
@@ -1,5 +1,7 @@
from operator import is_not
+from typing_extensions import Annotated
+
import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import func
@@ -21,6 +23,7 @@ from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import deferred
from sqlalchemy.orm import events as orm_events
from sqlalchemy.orm import has_inherited_table
+from sqlalchemy.orm import Mapped
from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship
from sqlalchemy.orm import synonym
@@ -1646,6 +1649,93 @@ class DeclarativeMixinPropertyTest(
m2,
)
+ @testing.combinations(
+ "anno",
+ "anno_w_clsmeth",
+ "pep593",
+ "nonanno",
+ "legacy",
+ argnames="clstype",
+ )
+ def test_column_property_col_ref(self, decl_base, clstype):
+
+ if clstype == "anno":
+
+ class SomethingMixin:
+ x: Mapped[int]
+ y: Mapped[int] = mapped_column()
+
+ @declared_attr
+ def x_plus_y(cls) -> Mapped[int]:
+ return column_property(cls.x + cls.y)
+
+ elif clstype == "anno_w_clsmeth":
+ # this form works better w/ pylance, so support it
+ class SomethingMixin:
+ x: Mapped[int]
+ y: Mapped[int] = mapped_column()
+
+ @declared_attr
+ @classmethod
+ def x_plus_y(cls) -> Mapped[int]:
+ return column_property(cls.x + cls.y)
+
+ elif clstype == "nonanno":
+
+ class SomethingMixin:
+ x = mapped_column(Integer)
+ y = mapped_column(Integer)
+
+ @declared_attr
+ def x_plus_y(cls) -> Mapped[int]:
+ return column_property(cls.x + cls.y)
+
+ elif clstype == "pep593":
+ myint = Annotated[int, mapped_column(Integer)]
+
+ class SomethingMixin:
+ x: Mapped[myint]
+ y: Mapped[myint]
+
+ @declared_attr
+ def x_plus_y(cls) -> Mapped[int]:
+ return column_property(cls.x + cls.y)
+
+ elif clstype == "legacy":
+
+ class SomethingMixin:
+ x = Column(Integer)
+ y = Column(Integer)
+
+ @declared_attr
+ def x_plus_y(cls) -> Mapped[int]:
+ return column_property(cls.x + cls.y)
+
+ else:
+ assert False
+
+ class Something(SomethingMixin, Base):
+ __tablename__ = "something"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ class SomethingElse(SomethingMixin, Base):
+ __tablename__ = "something_else"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ # use the mixin twice, make sure columns are copied, etc
+ self.assert_compile(
+ select(Something.x_plus_y),
+ "SELECT something.x + something.y AS anon_1 FROM something",
+ )
+
+ self.assert_compile(
+ select(SomethingElse.x_plus_y),
+ "SELECT something_else.x + something_else.y AS anon_1 "
+ "FROM something_else",
+ )
+
def test_doc(self):
"""test documentation transfer.
diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py
index cd45d96d1..c33aef9c4 100644
--- a/test/orm/declarative/test_typed_mapping.py
+++ b/test/orm/declarative/test_typed_mapping.py
@@ -815,10 +815,9 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__tablename__ = "a"
id: Mapped[int] = mapped_column(primary_key=True)
- # ordering of cols is TODO
- eq_(A.__table__.c.keys(), ["id", "y", "name", "x"])
+ eq_(A.__table__.c.keys(), ["id", "name", "x", "y"])
- self.assert_compile(select(A), "SELECT a.id, a.y, a.name, a.x FROM a")
+ self.assert_compile(select(A), "SELECT a.id, a.name, a.x, a.y FROM a")
def test_mapped_column_omit_fn_fixed_table(self, decl_base):
class MixinOne: