summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-09-26 02:33:19 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-09-26 02:33:19 +0000
commit1657cea73d5ec9aeedd541001e125e03e581a34b (patch)
treeb1d8527435fa51f7cec399972ea5af29d4f74a67
parente708cfea0bdaae82ac30dd7d33f9442115b9af6d (diff)
parentc86ec8f8c98b756ef06933174a3f4a0f3cfbed41 (diff)
downloadsqlalchemy-1657cea73d5ec9aeedd541001e125e03e581a34b.tar.gz
Merge "`aggregate_order_by` now supports cache generation." into main
-rw-r--r--doc/build/changelog/unreleased_14/8574.rst5
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py12
-rw-r--r--lib/sqlalchemy/testing/fixtures.py108
-rw-r--r--test/dialect/postgresql/test_compiler.py33
-rw-r--r--test/orm/test_cache_key.py5
-rw-r--r--test/orm/test_deprecations.py2
-rw-r--r--test/sql/test_compare.py106
7 files changed, 161 insertions, 110 deletions
diff --git a/doc/build/changelog/unreleased_14/8574.rst b/doc/build/changelog/unreleased_14/8574.rst
new file mode 100644
index 000000000..ffc1761c3
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8574.rst
@@ -0,0 +1,5 @@
+.. change::
+ :tags: usecase, postgresql
+ :tickets: 8574
+
+ :class:`_postgresql.aggregate_order_by` now supports cache generation.
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 0192cf581..ebaad2734 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -5,8 +5,10 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
+from __future__ import annotations
from itertools import zip_longest
+from typing import TYPE_CHECKING
from .array import ARRAY
from ...sql import coercions
@@ -16,6 +18,10 @@ from ...sql import functions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
+from ...sql.visitors import InternalTraversal
+
+if TYPE_CHECKING:
+ from ...sql.visitors import _TraverseInternalsType
class aggregate_order_by(expression.ColumnElement):
@@ -56,7 +62,11 @@ class aggregate_order_by(expression.ColumnElement):
__visit_name__ = "aggregate_order_by"
stringify_dialect = "postgresql"
- inherit_cache = False
+ _traverse_internals: _TraverseInternalsType = [
+ ("target", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index ef284babc..5fb547cbc 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -9,6 +9,7 @@
from __future__ import annotations
+import itertools
import re
import sys
@@ -16,6 +17,8 @@ import sqlalchemy as sa
from . import assertions
from . import config
from . import schema
+from .assertions import eq_
+from .assertions import ne_
from .entities import BasicEntity
from .entities import ComparableEntity
from .entities import ComparableMixin # noqa
@@ -27,6 +30,8 @@ from ..orm import DeclarativeBase
from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
+from ..sql import visitors
+from ..sql.elements import ClauseElement
@config.mark_base_test_class()
@@ -881,3 +886,106 @@ class ComputedReflectionFixtureTest(TablesTest):
Computed("normal * 42", persisted=True),
)
)
+
+
+class CacheKeyFixture:
+ def _compare_equal(self, a, b, compare_values):
+ a_key = a._generate_cache_key()
+ b_key = b._generate_cache_key()
+
+ if a_key is None:
+ assert a._annotations.get("nocache")
+
+ assert b_key is None
+ else:
+
+ eq_(a_key.key, b_key.key)
+ eq_(hash(a_key.key), hash(b_key.key))
+
+ for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
+ assert a_param.compare(b_param, compare_values=compare_values)
+ return a_key, b_key
+
+ def _run_cache_key_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ a_key, b_key = self._compare_equal(
+ case_a[a], case_b[b], compare_values
+ )
+ if a_key is None:
+ continue
+ else:
+ a_key = case_a[a]._generate_cache_key()
+ b_key = case_b[b]._generate_cache_key()
+
+ if a_key is None or b_key is None:
+ if a_key is None:
+ assert case_a[a]._annotations.get("nocache")
+ if b_key is None:
+ assert case_b[b]._annotations.get("nocache")
+ continue
+
+ if a_key.key == b_key.key:
+ for a_param, b_param in zip(
+ a_key.bindparams, b_key.bindparams
+ ):
+ if not a_param.compare(
+ b_param, compare_values=compare_values
+ ):
+ break
+ else:
+ # this fails unconditionally since we could not
+ # find bound parameter values that differed.
+ # Usually we intended to get two distinct keys here
+ # so the failure will be more descriptive using the
+ # ne_() assertion.
+ ne_(a_key.key, b_key.key)
+ else:
+ ne_(a_key.key, b_key.key)
+
+ # ClauseElement-specific test to ensure the cache key
+ # collected all the bound parameters that aren't marked
+ # as "literal execute"
+ if isinstance(case_a[a], ClauseElement) and isinstance(
+ case_b[b], ClauseElement
+ ):
+ assert_a_params = []
+ assert_b_params = []
+
+ for elem in visitors.iterate(case_a[a]):
+ if elem.__visit_name__ == "bindparam":
+ assert_a_params.append(elem)
+
+ for elem in visitors.iterate(case_b[b]):
+ if elem.__visit_name__ == "bindparam":
+ assert_b_params.append(elem)
+
+ # note we're asserting the order of the params as well as
+ # if there are dupes or not. ordering has to be
+ # deterministic and matches what a traversal would provide.
+ eq_(
+ sorted(a_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_a_params), key=lambda b: b.key
+ ),
+ )
+ eq_(
+ sorted(b_key.bindparams, key=lambda b: b.key),
+ sorted(
+ util.unique_list(assert_b_params), key=lambda b: b.key
+ ),
+ )
+
+ def _run_cache_key_equal_fixture(self, fixture, compare_values):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ self._compare_equal(case_a[a], case_b[b], compare_values)
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 67e54e4f5..c763dbeac 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -3465,3 +3465,36 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
"SELECT 1 " + exp,
checkparams=params,
)
+
+
+class CacheKeyTest(fixtures.CacheKeyFixture, fixtures.TestBase):
+ def test_aggregate_order_by(self):
+ """test #8574"""
+
+ self._run_cache_key_fixture(
+ lambda: (
+ aggregate_order_by(column("a"), column("a")),
+ aggregate_order_by(column("a"), column("b")),
+ aggregate_order_by(column("a"), column("a").desc()),
+ aggregate_order_by(column("a"), column("a").nulls_first()),
+ aggregate_order_by(
+ column("a"), column("a").desc().nulls_first()
+ ),
+ aggregate_order_by(column("a", Integer), column("b")),
+ aggregate_order_by(column("a"), column("b"), column("c")),
+ aggregate_order_by(column("a"), column("c"), column("b")),
+ aggregate_order_by(
+ column("a"), column("b").desc(), column("c")
+ ),
+ aggregate_order_by(
+ column("a"), column("b").nulls_first(), column("c")
+ ),
+ aggregate_order_by(
+ column("a"), column("b").desc().nulls_first(), column("c")
+ ),
+ aggregate_order_by(
+ column("a", Integer), column("a"), column("b")
+ ),
+ ),
+ compare_values=False,
+ )
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
index 08fd22dc8..70770089c 100644
--- a/test/orm/test_cache_key.py
+++ b/test/orm/test_cache_key.py
@@ -42,7 +42,6 @@ from sqlalchemy.testing.fixtures import fixture_session
from test.orm import _fixtures
from .inheritance import _poly_fixtures
from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
def stmt_20(*elements):
@@ -52,7 +51,7 @@ def stmt_20(*elements):
)
-class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
+class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest):
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
@@ -591,7 +590,7 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
)
-class PolyCacheKeyTest(CacheKeyFixture, _poly_fixtures._Polymorphic):
+class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py
index b012009c8..71c03aee7 100644
--- a/test/orm/test_deprecations.py
+++ b/test/orm/test_deprecations.py
@@ -52,6 +52,7 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
+from sqlalchemy.testing.fixtures import CacheKeyFixture
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@@ -65,7 +66,6 @@ from .test_options import PathTest as OptionsPathTest
from .test_options import PathTest
from .test_options import QueryTest as OptionsQueryTest
from .test_query import QueryTest
-from ..sql.test_compare import CacheKeyFixture
if True:
# hack - zimports won't stop reformatting this to be too-long for now
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index 18f26887a..30ca5c569 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -27,7 +27,6 @@ from sqlalchemy import tuple_
from sqlalchemy import TypeDecorator
from sqlalchemy import union
from sqlalchemy import union_all
-from sqlalchemy import util
from sqlalchemy import values
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
@@ -1054,110 +1053,7 @@ class CoreFixtures:
]
-class CacheKeyFixture:
- def _compare_equal(self, a, b, compare_values):
- a_key = a._generate_cache_key()
- b_key = b._generate_cache_key()
-
- if a_key is None:
- assert a._annotations.get("nocache")
-
- assert b_key is None
- else:
-
- eq_(a_key.key, b_key.key)
- eq_(hash(a_key.key), hash(b_key.key))
-
- for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
- assert a_param.compare(b_param, compare_values=compare_values)
- return a_key, b_key
-
- def _run_cache_key_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- if a == b:
- a_key, b_key = self._compare_equal(
- case_a[a], case_b[b], compare_values
- )
- if a_key is None:
- continue
- else:
- a_key = case_a[a]._generate_cache_key()
- b_key = case_b[b]._generate_cache_key()
-
- if a_key is None or b_key is None:
- if a_key is None:
- assert case_a[a]._annotations.get("nocache")
- if b_key is None:
- assert case_b[b]._annotations.get("nocache")
- continue
-
- if a_key.key == b_key.key:
- for a_param, b_param in zip(
- a_key.bindparams, b_key.bindparams
- ):
- if not a_param.compare(
- b_param, compare_values=compare_values
- ):
- break
- else:
- # this fails unconditionally since we could not
- # find bound parameter values that differed.
- # Usually we intended to get two distinct keys here
- # so the failure will be more descriptive using the
- # ne_() assertion.
- ne_(a_key.key, b_key.key)
- else:
- ne_(a_key.key, b_key.key)
-
- # ClauseElement-specific test to ensure the cache key
- # collected all the bound parameters that aren't marked
- # as "literal execute"
- if isinstance(case_a[a], ClauseElement) and isinstance(
- case_b[b], ClauseElement
- ):
- assert_a_params = []
- assert_b_params = []
-
- for elem in visitors.iterate(case_a[a]):
- if elem.__visit_name__ == "bindparam":
- assert_a_params.append(elem)
-
- for elem in visitors.iterate(case_b[b]):
- if elem.__visit_name__ == "bindparam":
- assert_b_params.append(elem)
-
- # note we're asserting the order of the params as well as
- # if there are dupes or not. ordering has to be
- # deterministic and matches what a traversal would provide.
- eq_(
- sorted(a_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_a_params), key=lambda b: b.key
- ),
- )
- eq_(
- sorted(b_key.bindparams, key=lambda b: b.key),
- sorted(
- util.unique_list(assert_b_params), key=lambda b: b.key
- ),
- )
-
- def _run_cache_key_equal_fixture(self, fixture, compare_values):
- case_a = fixture()
- case_b = fixture()
-
- for a, b in itertools.combinations_with_replacement(
- range(len(case_a)), 2
- ):
- self._compare_equal(case_a[a], case_b[b], compare_values)
-
-
-class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
+class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
# we are slightly breaking the policy of not having external dialect
# stuff in here, but use pg/mysql as test cases to ensure that these
# objects don't report an inaccurate cache key, which is dependent