diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2023-04-10 21:01:11 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@bbpush.zzzcomputing.com> | 2023-04-10 21:01:11 +0000 |
commit | 33a9947686a7c8bb3becd6169ec7abe84cf2c7e6 (patch) | |
tree | 2ea396fa8b4552ab126d3b01e523b80ea4917b80 | |
parent | 1ae44556907f14f92874f05d05242cb57bb0f855 (diff) | |
parent | 157c521736f1c9cfceb9b3a6ecf17f782d358c46 (diff) | |
download | alembic-33a9947686a7c8bb3becd6169ec7abe84cf2c7e6.tar.gz |
Merge "Use column sort in index compare on postgresql" into main
-rw-r--r-- | alembic/autogenerate/compare.py | 24 | ||||
-rw-r--r-- | alembic/ddl/postgresql.py | 73 | ||||
-rw-r--r-- | docs/build/unreleased/1213.rst | 6 | ||||
-rw-r--r-- | tests/requirements.py | 6 | ||||
-rw-r--r-- | tests/test_autogen_indexes.py | 288 |
5 files changed, 305 insertions, 92 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 4f5126f..85cb426 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -8,6 +8,7 @@ from typing import cast from typing import Dict from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import Set from typing import Tuple @@ -19,6 +20,7 @@ from sqlalchemy import inspect from sqlalchemy import schema as sa_schema from sqlalchemy import text from sqlalchemy import types as sqltypes +from sqlalchemy.sql import expression from sqlalchemy.util import OrderedSet from alembic.ddl.base import _fk_spec @@ -278,15 +280,35 @@ def _compare_tables( upgrade_ops.ops.append(modify_table_ops) +_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict( + { + "asc": expression.asc, + "desc": expression.desc, + "nulls_first": expression.nullsfirst, + "nulls_last": expression.nullslast, + "nullsfirst": expression.nullsfirst, # 1_3 name + "nullslast": expression.nullslast, # 1_3 name + } +) + + def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]: exprs: list[Union[Column[Any], TextClause]] = [] + sorting = params.get("column_sorting") + for num, col_name in enumerate(params["column_names"]): item: Union[Column[Any], TextClause] if col_name is None: assert "expressions" in params - item = text(params["expressions"][num]) + name = params["expressions"][num] + item = text(name) else: + name = col_name item = conn_table.c[col_name] + if sorting and name in sorting: + for operator in sorting[name]: + if operator in _IndexColumnSortingOps: + item = _IndexColumnSortingOps[operator](item) exprs.append(item) ix = sa_schema.Index( params["name"], *exprs, unique=params["unique"], _table=conn_table diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 4ffc2eb..247838b 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql import operators from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.types import NULLTYPE from .base import alter_column @@ -53,6 +55,7 @@ if TYPE_CHECKING: from sqlalchemy.dialects.postgresql.json import JSON from sqlalchemy.dialects.postgresql.json import JSONB from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import Table @@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl): if not sqla_compat.sqla_2: self._skip_functional_indexes(metadata_indexes, conn_indexes) - def _cleanup_index_expr(self, index: Index, expr: str) -> str: + def _cleanup_index_expr( + self, index: Index, expr: str, remove_suffix: str + ) -> str: # start = expr expr = expr.lower() expr = expr.replace('"', "") if index.table is not None: + # should not be needed, since include_table=False is in compile expr = expr.replace(f"{index.table.name.lower()}.", "") while expr and expr[0] == "(" and expr[-1] == ")": @@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl): # strip :: cast. types can have spaces in them expr = re.sub(r"(::[\w ]+\w)", "", expr) + if remove_suffix and expr.endswith(remove_suffix): + expr = expr[: -len(remove_suffix)] + # print(f"START: {start} END: {expr}") return expr + def _default_modifiers(self, exp: ClauseElement) -> str: + to_remove = "" + while isinstance(exp, UnaryExpression): + if exp.modifier is None: + exp = exp.element + else: + op = exp.modifier + if isinstance(exp.element, UnaryExpression): + inner_op = exp.element.modifier + else: + inner_op = None + if inner_op is None: + if op == operators.asc_op: + # default is asc + to_remove = " asc" + elif op == operators.nullslast_op: + # default is nulls last + to_remove = " nulls last" + else: + if ( + inner_op == operators.asc_op + and op == operators.nullslast_op + ): + # default is asc nulls last + to_remove = " asc nulls last" + elif ( + inner_op == operators.desc_op + and op == operators.nullsfirst_op + ): + # default for desc is nulls first + to_remove = " nulls first" + break + return to_remove + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: - if sqla_compat.is_expression_index(index): - return tuple( - self._cleanup_index_expr( - index, - e + return tuple( + self._cleanup_index_expr( + index, + *( + (e, "") if isinstance(e, str) - else e.compile( - dialect=self.dialect, - compile_kwargs={"literal_binds": True}, - ).string, - ) - for e in index.expressions + else (self._compile_element(e), self._default_modifiers(e)) + ), ) - else: - return super().create_index_sig(index) + for e in index.expressions + ) + + def _compile_element(self, element: ClauseElement) -> str: + return element.compile( + dialect=self.dialect, + compile_kwargs={"literal_binds": True, "include_table": False}, + ).string def render_type( self, type_: TypeEngine, autogen_context: AutogenContext diff --git a/docs/build/unreleased/1213.rst b/docs/build/unreleased/1213.rst new file mode 100644 index 0000000..29b88a4 --- /dev/null +++ b/docs/build/unreleased/1213.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: postgresql, autogenerate + :tickets: 1213 + + Added support for autogenerate comparison of indexes on PostgreSQL which + include SQL sort option, such as ``ASC`` or ``NULLS FIRST``. diff --git a/tests/requirements.py b/tests/requirements.py index 1a100dd..dbbb88a 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -138,9 +138,15 @@ class DefaultRequirements(SuiteRequirements): def reflects_indexes_w_sorting(self): # TODO: figure out what's happening on the SQLAlchemy side # when we reflect an index that has asc() / desc() on the column + # Tracked by https://github.com/sqlalchemy/sqlalchemy/issues/9597 return exclusions.fails_on(["oracle"]) @property + def reflects_indexes_column_sorting(self): + "Actually reflect column_sorting on the indexes" + return exclusions.only_on(["postgresql"]) + + @property def long_names(self): if sqla_compat.sqla_14: return exclusions.skip_if("oracle<18") diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index 30b7d90..f697e5a 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -15,8 +15,11 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION +from sqlalchemy.sql.expression import asc from sqlalchemy.sql.expression import column from sqlalchemy.sql.expression import desc +from sqlalchemy.sql.expression import nullsfirst +from sqlalchemy.sql.expression import nullslast from alembic import testing from alembic.testing import combinations @@ -1130,11 +1133,6 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase): eq_(diffs, []) - # fails in the 0.8 series where we have truncation rules, - # but no control over quoting. passes in 0.7.9 where we don't have - # truncation rules either. dropping these ancient versions - # is long overdue. - def test_unchanged_case_sensitive_implicit_idx(self): m1 = MetaData() m2 = MetaData() @@ -1216,6 +1214,206 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase): ], ) + @config.requirements.reflects_indexes_column_sorting + @testing.combinations( + (desc, asc), + (asc, desc), + (desc, lambda x: nullslast(desc(x))), + (nullslast, nullsfirst), + (nullsfirst, nullslast), + (lambda x: nullslast(desc(x)), lambda x: nullsfirst(asc(x))), + ) + def test_column_sort_changed(self, old_fn, new_fn): + m1 = MetaData() + m2 = MetaData() + + old = Index("SomeIndex", old_fn("y")) + Table("order_change", m1, Column("y", Integer), old) + + new = Index("SomeIndex", new_fn("y")) + Table("order_change", m2, Column("y", Integer), new) + diffs = self._fixture(m1, m2) + eq_( + diffs, + [ + ( + "remove_index", + schemacompare.CompareIndex(old, name_only=True), + ), + ("add_index", schemacompare.CompareIndex(new, name_only=True)), + ], + ) + + @config.requirements.reflects_indexes_column_sorting + @testing.combinations( + (asc, asc), + (desc, desc), + (nullslast, nullslast), + (nullsfirst, nullsfirst), + (lambda x: x, asc), + (lambda x: x, nullslast), + (desc, lambda x: nullsfirst(desc(x))), + (lambda x: nullslast(asc(x)), lambda x: x), + ) + def test_column_sort_not_changed(self, old_fn, new_fn): + m1 = MetaData() + m2 = MetaData() + + old = Index("SomeIndex", old_fn("y")) + Table("order_change", m1, Column("y", Integer), old) + + new = Index("SomeIndex", new_fn("y")) + Table("order_change", m2, Column("y", Integer), new) + diffs = self._fixture(m1, m2) + eq_(diffs, []) + + +def _lots_of_indexes(flatten: bool = False): + diff_pairs = [ + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", func.lower(t.c.x)), + ), + ( + lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)), + lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)), + ), + ( + lambda t: Index( + "SomeIndex", "y", func.lower(column("x")), _table=t + ), + lambda t: Index("SomeIndex", func.lower(column("x")), _table=t), + ), + ( + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y), + ), + ( + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)), + ), + ( + lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)), + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + ), + ( + lambda t: Index("SomeIndex", func.lower(t.c.x)), + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + ), + ( + lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)), + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + ), + ( + lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1), + lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3), + ), + ( + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), + lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)), + ), + ( + lambda t: Index("SomeIndex", t.c.y, t.c.z + 3), + lambda t: Index("SomeIndex", t.c.y, t.c.z * 3), + ), + ( + lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"), + lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"), + ), + ( + lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42), + lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)), + ), + ( + lambda t: Index("SomeIndex", t.c.ff + 42), + lambda t: Index("SomeIndex", 42 + t.c.ff), + ), + ] + + with_sort = [ + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))), + ), + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", desc("y"), func.lower(t.c.x)), + ), + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", "y", nullsfirst(func.lower(t.c.x))), + ), + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", nullsfirst("y"), func.lower(t.c.x)), + ), + ( + lambda t: Index("SomeIndex", asc(func.lower(t.c.x))), + lambda t: Index("SomeIndex", desc(func.lower(t.c.x))), + ), + ( + lambda t: Index("SomeIndex", desc(func.lower(t.c.x))), + lambda t: Index("SomeIndex", asc(func.lower(t.c.x))), + ), + ( + lambda t: Index("SomeIndex", nullslast(asc(func.lower(t.c.x)))), + lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))), + ), + ( + lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))), + lambda t: Index("SomeIndex", nullsfirst(desc(func.lower(t.c.x)))), + ), + ( + lambda t: Index("SomeIndex", nullsfirst(func.lower(t.c.x))), + lambda t: Index("SomeIndex", desc(func.lower(t.c.x))), + ), + ] + + req = config.requirements.reflects_indexes_column_sorting + + if flatten: + + flat = list(itertools.chain.from_iterable(diff_pairs)) + for f1, f2 in with_sort: + flat.extend([(f1, req), (f2, req)]) + return flat + else: + return diff_pairs + [(f1, f2, req) for f1, f2 in with_sort] + + +def _lost_of_equal_indexes(_lots_of_indexes): + equal_pairs = [ + (fn, fn) if not isinstance(fn, tuple) else (fn[0], fn[0], fn[1]) + for fn in _lots_of_indexes(flatten=True) + ] + equal_pairs += [ + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", "y", asc(func.lower(t.c.x))), + config.requirements.reflects_indexes_column_sorting, + ), + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index("SomeIndex", "y", nullslast(func.lower(t.c.x))), + config.requirements.reflects_indexes_column_sorting, + ), + ( + lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), + lambda t: Index( + "SomeIndex", "y", nullslast(asc(func.lower(t.c.x))) + ), + config.requirements.reflects_indexes_column_sorting, + ), + ( + lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))), + lambda t: Index( + "SomeIndex", "y", nullsfirst(desc(func.lower(t.c.x))) + ), + config.requirements.reflects_indexes_column_sorting, + ), + ] + return equal_pairs + class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase): """tests involving indexes with expression""" @@ -1263,74 +1461,6 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase): diffs = self._fixture(m1, m2) eq_(diffs, []) - def _lots_of_indexes(flatten: bool = False): - diff_pairs = [ - ( - lambda t: Index("SomeIndex", "y", func.lower(t.c.x)), - lambda t: Index("SomeIndex", func.lower(t.c.x)), - ), - ( - lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)), - lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)), - ), - ( - lambda t: Index( - "SomeIndex", "y", func.lower(column("x")), _table=t - ), - lambda t: Index( - "SomeIndex", func.lower(column("x")), _table=t - ), - ), - ( - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y), - ), - ( - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)), - ), - ( - lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)), - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - ), - ( - lambda t: Index("SomeIndex", func.lower(t.c.x)), - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - ), - ( - lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)), - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - ), - ( - lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1), - lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3), - ), - ( - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)), - lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)), - ), - ( - lambda t: Index("SomeIndex", t.c.y, t.c.z + 3), - lambda t: Index("SomeIndex", t.c.y, t.c.z * 3), - ), - ( - lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"), - lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"), - ), - ( - lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42), - lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)), - ), - ( - lambda t: Index("SomeIndex", t.c.ff + 42), - lambda t: Index("SomeIndex", 42 + t.c.ff), - ), - ] - if flatten: - return list(itertools.chain.from_iterable(diff_pairs)) - else: - return diff_pairs - @testing.fixture def index_changed_tables(self): m1 = MetaData() @@ -1412,12 +1542,16 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase): diffs = self._fixture(m1, m2) eq_(diffs, []) - @combinations(*_lots_of_indexes(flatten=True), argnames="fn") - def test_expression_indexes_no_change(self, index_changed_tables, fn): + @combinations( + *_lost_of_equal_indexes(_lots_of_indexes), argnames="fn1, fn2" + ) + def test_expression_indexes_no_change( + self, index_changed_tables, fn1, fn2 + ): m1, m2, old_fixture_tables, new_fixture_tables = index_changed_tables - resolve_lambda(fn, **old_fixture_tables) - resolve_lambda(fn, **new_fixture_tables) + resolve_lambda(fn1, **old_fixture_tables) + resolve_lambda(fn2, **new_fixture_tables) if self.has_reflection: ctx = nullcontext() |