summaryrefslogtreecommitdiff
path: root/test/orm/inheritance/test_poly_loading.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/orm/inheritance/test_poly_loading.py')
-rw-r--r--test/orm/inheritance/test_poly_loading.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/test/orm/inheritance/test_poly_loading.py b/test/orm/inheritance/test_poly_loading.py
index f03f15bd2..9086be3c4 100644
--- a/test/orm/inheritance/test_poly_loading.py
+++ b/test/orm/inheritance/test_poly_loading.py
@@ -8,6 +8,7 @@ from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import union
from sqlalchemy.orm import backref
+from sqlalchemy.orm import column_property
from sqlalchemy.orm import composite
from sqlalchemy.orm import defaultload
from sqlalchemy.orm import immediateload
@@ -1174,3 +1175,111 @@ class CompositeAttributesTest(fixtures.TestBase):
B(id=2, thing2="thing2", comp2=XYThing(3, 4)),
],
)
+
+
+class PolymorphicOnExprTest(
+ testing.AssertsExecutionResults, fixtures.TestBase
+):
+ """test for #8704"""
+
+ @testing.fixture()
+ def poly_fixture(self, connection, decl_base):
+ def fixture(create_prop, use_load):
+ class TypeTable(decl_base):
+ __tablename__ = "type"
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String(30))
+
+ class PolyBase(ComparableEntity, decl_base):
+ __tablename__ = "base"
+
+ id = Column(Integer, primary_key=True)
+ type_id = Column(ForeignKey(TypeTable.id))
+
+ if create_prop == "create_prop":
+ polymorphic = column_property(
+ select(TypeTable.name)
+ .where(TypeTable.id == type_id)
+ .scalar_subquery()
+ )
+ __mapper_args__ = {
+ "polymorphic_on": polymorphic,
+ }
+ elif create_prop == "dont_create_prop":
+ __mapper_args__ = {
+ "polymorphic_on": select(TypeTable.name)
+ .where(TypeTable.id == type_id)
+ .scalar_subquery()
+ }
+ elif create_prop == "arg_level_prop":
+ __mapper_args__ = {
+ "polymorphic_on": column_property(
+ select(TypeTable.name)
+ .where(TypeTable.id == type_id)
+ .scalar_subquery()
+ )
+ }
+
+ class Foo(PolyBase):
+ __tablename__ = "foo"
+
+ if use_load == "use_polymorphic_load":
+ __mapper_args__ = {
+ "polymorphic_identity": "foo",
+ "polymorphic_load": "selectin",
+ }
+ else:
+ __mapper_args__ = {
+ "polymorphic_identity": "foo",
+ }
+
+ id = Column(ForeignKey(PolyBase.id), primary_key=True)
+ foo_attr = Column(String(30))
+
+ decl_base.metadata.create_all(connection)
+
+ with Session(connection) as session:
+ foo_type = TypeTable(name="foo")
+ session.add(foo_type)
+ session.flush()
+
+ foo = Foo(type_id=foo_type.id, foo_attr="foo value")
+ session.add(foo)
+
+ session.commit()
+
+ return PolyBase, Foo, TypeTable
+
+ yield fixture
+
+ @testing.combinations(
+ "create_prop",
+ "dont_create_prop",
+ "arg_level_prop",
+ argnames="create_prop",
+ )
+ @testing.combinations(
+ "use_polymorphic_load",
+ "use_loader_option",
+ "none",
+ argnames="use_load",
+ )
+ def test_load_selectin(
+ self, poly_fixture, connection, create_prop, use_load
+ ):
+ PolyBase, Foo, TypeTable = poly_fixture(create_prop, use_load)
+
+ sess = Session(connection)
+
+ foo_type = sess.scalars(select(TypeTable)).one()
+
+ stmt = select(PolyBase)
+ if use_load == "use_loader_option":
+ stmt = stmt.options(selectin_polymorphic(PolyBase, [Foo]))
+ obj = sess.scalars(stmt).all()
+
+ def go():
+ eq_(obj, [Foo(type_id=foo_type.id, foo_attr="foo value")])
+
+ self.assert_sql_count(testing.db, go, 0 if use_load != "none" else 1)