summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-03-23 16:58:12 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-03-23 16:58:12 +0000
commit675d11f113b3eb5a931d8dfb31328feee2080c27 (patch)
tree23e3c4001cd659a4a257691b1b85be96ea06e8d0
parentc565c470517e1cc70a7f33d1ad3d3256935f1121 (diff)
parentd051645463b169bf1535459653eff247cb772e62 (diff)
downloadsqlalchemy-675d11f113b3eb5a931d8dfb31328feee2080c27.tar.gz
Merge "trust user PK argument as given; don't reduce" into main
-rw-r--r--doc/build/changelog/unreleased_14/7842.rst12
-rw-r--r--lib/sqlalchemy/orm/mapper.py18
-rw-r--r--test/orm/test_mapper.py61
3 files changed, 82 insertions, 9 deletions
diff --git a/doc/build/changelog/unreleased_14/7842.rst b/doc/build/changelog/unreleased_14/7842.rst
new file mode 100644
index 000000000..c165ed44b
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/7842.rst
@@ -0,0 +1,12 @@
+.. change::
+ :tags: bug, orm
+ :tickets: 7842
+
+ Fixed issue where the :class:`_orm.Mapper` would reduce a user-defined
+ :paramref:`_orm.Mapper.primary_key` argument too aggressively, in the case
+ of mapping to a ``UNION`` where for some of the SELECT entries, two columns
+ are essentially equivalent, but in another, they are not, such as in a
+ recursive CTE. The logic here has been changed to accept a given
+ user-defined PK as given, where columns will be related to the mapped
+ selectable but no longer "reduced" as this heuristic can't accommodate for
+ all situations.
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 011e7d2ef..7d1fc7643 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1368,17 +1368,17 @@ class Mapper(
# that of the inheriting (unless concrete or explicit)
self.primary_key = self.inherits.primary_key
else:
- # determine primary key from argument or persist_selectable pks -
- # reduce to the minimal set of columns
+ # determine primary key from argument or persist_selectable pks
if self._primary_key_argument:
- primary_key = sql_util.reduce_columns(
- [
- self.persist_selectable.corresponding_column(c)
- for c in self._primary_key_argument
- ],
- ignore_nonexistent_tables=True,
- )
+ primary_key = [
+ self.persist_selectable.corresponding_column(c)
+ for c in self._primary_key_argument
+ ]
else:
+ # if heuristically determined PKs, reduce to the minimal set
+ # of columns by eliminating FK->PK pairs for a multi-table
+ # expression. May over-reduce for some kinds of UNIONs
+ # / CTEs; use explicit PK argument for these special cases
primary_key = sql_util.reduce_columns(
self._pks_by_table[self.persist_selectable],
ignore_nonexistent_tables=True,
diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py
index 1fad974b9..980c82fbe 100644
--- a/test/orm/test_mapper.py
+++ b/test/orm/test_mapper.py
@@ -5,6 +5,7 @@ import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import literal
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
@@ -41,6 +42,7 @@ from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import ComparableEntity
from sqlalchemy.testing.fixtures import ComparableMixin
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
@@ -1403,6 +1405,65 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
],
)
+ @testing.requires.ctes
+ def test_mapping_to_union_dont_overlimit_pk(self, registry, connection):
+ """test #7842"""
+ Base = registry.generate_base()
+
+ class Node(Base):
+ __tablename__ = "cte_nodes"
+
+ id = Column(Integer, primary_key=True)
+ parent = Column(Integer, ForeignKey("cte_nodes.id"))
+
+ # so we dont have to deal with NULLS FIRST
+ sort_key = Column(Integer)
+
+ class NodeRel(ComparableEntity, Base):
+ table = select(
+ Node.id, Node.parent, Node.sort_key, literal(0).label("depth")
+ ).cte(recursive=True)
+ __table__ = table.union_all(
+ select(
+ Node.id,
+ table.c.parent,
+ table.c.sort_key,
+ table.c.depth + literal(1),
+ )
+ .select_from(Node)
+ .join(table, Node.parent == table.c.id)
+ )
+
+ __mapper_args__ = {
+ "primary_key": (__table__.c.id, __table__.c.parent)
+ }
+
+ nt = NodeRel.__table__
+
+ eq_(NodeRel.__mapper__.primary_key, (nt.c.id, nt.c.parent))
+
+ registry.metadata.create_all(connection)
+ with Session(connection) as session:
+ n1, n2, n3, n4 = (
+ Node(id=1, sort_key=1),
+ Node(id=2, parent=1, sort_key=2),
+ Node(id=3, parent=2, sort_key=3),
+ Node(id=4, parent=3, sort_key=4),
+ )
+ session.add_all([n1, n2, n3, n4])
+ session.commit()
+
+ q_rel = select(NodeRel).filter_by(id=4).order_by(NodeRel.sort_key)
+ eq_(
+ session.scalars(q_rel).all(),
+ [
+ NodeRel(id=4, parent=None),
+ NodeRel(id=4, parent=1),
+ NodeRel(id=4, parent=2),
+ NodeRel(id=4, parent=3),
+ ],
+ )
+
def test_scalar_pk_arg(self):
users, Keyword, items, Item, User, keywords = (
self.tables.users,