summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2023-04-10 21:01:11 +0000
committerGerrit Code Review <gerrit@bbpush.zzzcomputing.com>2023-04-10 21:01:11 +0000
commit33a9947686a7c8bb3becd6169ec7abe84cf2c7e6 (patch)
tree2ea396fa8b4552ab126d3b01e523b80ea4917b80
parent1ae44556907f14f92874f05d05242cb57bb0f855 (diff)
parent157c521736f1c9cfceb9b3a6ecf17f782d358c46 (diff)
downloadalembic-33a9947686a7c8bb3becd6169ec7abe84cf2c7e6.tar.gz
Merge "Use column sort in index compare on postgresql" into main
-rw-r--r--alembic/autogenerate/compare.py24
-rw-r--r--alembic/ddl/postgresql.py73
-rw-r--r--docs/build/unreleased/1213.rst6
-rw-r--r--tests/requirements.py6
-rw-r--r--tests/test_autogen_indexes.py288
5 files changed, 305 insertions, 92 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 4f5126f..85cb426 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -8,6 +8,7 @@ from typing import cast
from typing import Dict
from typing import Iterator
from typing import List
+from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
@@ -19,6 +20,7 @@ from sqlalchemy import inspect
from sqlalchemy import schema as sa_schema
from sqlalchemy import text
from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import expression
from sqlalchemy.util import OrderedSet
from alembic.ddl.base import _fk_spec
@@ -278,15 +280,35 @@ def _compare_tables(
upgrade_ops.ops.append(modify_table_ops)
+_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
+ {
+ "asc": expression.asc,
+ "desc": expression.desc,
+ "nulls_first": expression.nullsfirst,
+ "nulls_last": expression.nullslast,
+ "nullsfirst": expression.nullsfirst, # 1_3 name
+ "nullslast": expression.nullslast, # 1_3 name
+ }
+)
+
+
def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
exprs: list[Union[Column[Any], TextClause]] = []
+ sorting = params.get("column_sorting")
+
for num, col_name in enumerate(params["column_names"]):
item: Union[Column[Any], TextClause]
if col_name is None:
assert "expressions" in params
- item = text(params["expressions"][num])
+ name = params["expressions"][num]
+ item = text(name)
else:
+ name = col_name
item = conn_table.c[col_name]
+ if sorting and name in sorting:
+ for operator in sorting[name]:
+ if operator in _IndexColumnSortingOps:
+ item = _IndexColumnSortingOps[operator](item)
exprs.append(item)
ix = sa_schema.Index(
params["name"], *exprs, unique=params["unique"], _table=conn_table
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 4ffc2eb..247838b 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
+from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
@@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
- def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+ def _cleanup_index_expr(
+ self, index: Index, expr: str, remove_suffix: str
+ ) -> str:
# start = expr
expr = expr.lower()
expr = expr.replace('"', "")
if index.table is not None:
+ # should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
@@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl):
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
+ if remove_suffix and expr.endswith(remove_suffix):
+ expr = expr[: -len(remove_suffix)]
+
# print(f"START: {start} END: {expr}")
return expr
+ def _default_modifiers(self, exp: ClauseElement) -> str:
+ to_remove = ""
+ while isinstance(exp, UnaryExpression):
+ if exp.modifier is None:
+ exp = exp.element
+ else:
+ op = exp.modifier
+ if isinstance(exp.element, UnaryExpression):
+ inner_op = exp.element.modifier
+ else:
+ inner_op = None
+ if inner_op is None:
+ if op == operators.asc_op:
+ # default is asc
+ to_remove = " asc"
+ elif op == operators.nullslast_op:
+ # default is nulls last
+ to_remove = " nulls last"
+ else:
+ if (
+ inner_op == operators.asc_op
+ and op == operators.nullslast_op
+ ):
+ # default is asc nulls last
+ to_remove = " asc nulls last"
+ elif (
+ inner_op == operators.desc_op
+ and op == operators.nullsfirst_op
+ ):
+ # default for desc is nulls first
+ to_remove = " nulls first"
+ break
+ return to_remove
+
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
- if sqla_compat.is_expression_index(index):
- return tuple(
- self._cleanup_index_expr(
- index,
- e
+ return tuple(
+ self._cleanup_index_expr(
+ index,
+ *(
+ (e, "")
if isinstance(e, str)
- else e.compile(
- dialect=self.dialect,
- compile_kwargs={"literal_binds": True},
- ).string,
- )
- for e in index.expressions
+ else (self._compile_element(e), self._default_modifiers(e))
+ ),
)
- else:
- return super().create_index_sig(index)
+ for e in index.expressions
+ )
+
+ def _compile_element(self, element: ClauseElement) -> str:
+ return element.compile(
+ dialect=self.dialect,
+ compile_kwargs={"literal_binds": True, "include_table": False},
+ ).string
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
diff --git a/docs/build/unreleased/1213.rst b/docs/build/unreleased/1213.rst
new file mode 100644
index 0000000..29b88a4
--- /dev/null
+++ b/docs/build/unreleased/1213.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: postgresql, autogenerate
+ :tickets: 1213
+
+ Added support for autogenerate comparison of indexes on PostgreSQL which
+ include SQL sort option, such as ``ASC`` or ``NULLS FIRST``.
diff --git a/tests/requirements.py b/tests/requirements.py
index 1a100dd..dbbb88a 100644
--- a/tests/requirements.py
+++ b/tests/requirements.py
@@ -138,9 +138,15 @@ class DefaultRequirements(SuiteRequirements):
def reflects_indexes_w_sorting(self):
# TODO: figure out what's happening on the SQLAlchemy side
# when we reflect an index that has asc() / desc() on the column
+ # Tracked by https://github.com/sqlalchemy/sqlalchemy/issues/9597
return exclusions.fails_on(["oracle"])
@property
+ def reflects_indexes_column_sorting(self):
+ "Actually reflect column_sorting on the indexes"
+ return exclusions.only_on(["postgresql"])
+
+ @property
def long_names(self):
if sqla_compat.sqla_14:
return exclusions.skip_if("oracle<18")
diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py
index 30b7d90..f697e5a 100644
--- a/tests/test_autogen_indexes.py
+++ b/tests/test_autogen_indexes.py
@@ -15,8 +15,11 @@ from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION
+from sqlalchemy.sql.expression import asc
from sqlalchemy.sql.expression import column
from sqlalchemy.sql.expression import desc
+from sqlalchemy.sql.expression import nullsfirst
+from sqlalchemy.sql.expression import nullslast
from alembic import testing
from alembic.testing import combinations
@@ -1130,11 +1133,6 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
eq_(diffs, [])
- # fails in the 0.8 series where we have truncation rules,
- # but no control over quoting. passes in 0.7.9 where we don't have
- # truncation rules either. dropping these ancient versions
- # is long overdue.
-
def test_unchanged_case_sensitive_implicit_idx(self):
m1 = MetaData()
m2 = MetaData()
@@ -1216,6 +1214,206 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
],
)
+ @config.requirements.reflects_indexes_column_sorting
+ @testing.combinations(
+ (desc, asc),
+ (asc, desc),
+ (desc, lambda x: nullslast(desc(x))),
+ (nullslast, nullsfirst),
+ (nullsfirst, nullslast),
+ (lambda x: nullslast(desc(x)), lambda x: nullsfirst(asc(x))),
+ )
+ def test_column_sort_changed(self, old_fn, new_fn):
+ m1 = MetaData()
+ m2 = MetaData()
+
+ old = Index("SomeIndex", old_fn("y"))
+ Table("order_change", m1, Column("y", Integer), old)
+
+ new = Index("SomeIndex", new_fn("y"))
+ Table("order_change", m2, Column("y", Integer), new)
+ diffs = self._fixture(m1, m2)
+ eq_(
+ diffs,
+ [
+ (
+ "remove_index",
+ schemacompare.CompareIndex(old, name_only=True),
+ ),
+ ("add_index", schemacompare.CompareIndex(new, name_only=True)),
+ ],
+ )
+
+ @config.requirements.reflects_indexes_column_sorting
+ @testing.combinations(
+ (asc, asc),
+ (desc, desc),
+ (nullslast, nullslast),
+ (nullsfirst, nullsfirst),
+ (lambda x: x, asc),
+ (lambda x: x, nullslast),
+ (desc, lambda x: nullsfirst(desc(x))),
+ (lambda x: nullslast(asc(x)), lambda x: x),
+ )
+ def test_column_sort_not_changed(self, old_fn, new_fn):
+ m1 = MetaData()
+ m2 = MetaData()
+
+ old = Index("SomeIndex", old_fn("y"))
+ Table("order_change", m1, Column("y", Integer), old)
+
+ new = Index("SomeIndex", new_fn("y"))
+ Table("order_change", m2, Column("y", Integer), new)
+ diffs = self._fixture(m1, m2)
+ eq_(diffs, [])
+
+
+def _lots_of_indexes(flatten: bool = False):
+ diff_pairs = [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", func.lower(t.c.x)),
+ ),
+ (
+ lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
+ lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
+ ),
+ (
+ lambda t: Index(
+ "SomeIndex", "y", func.lower(column("x")), _table=t
+ ),
+ lambda t: Index("SomeIndex", func.lower(column("x")), _table=t),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
+ lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
+ lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
+ lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
+ ),
+ (
+ lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
+ lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
+ ),
+ (
+ lambda t: Index("SomeIndex", t.c.ff + 42),
+ lambda t: Index("SomeIndex", 42 + t.c.ff),
+ ),
+ ]
+
+ with_sort = [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", desc("y"), func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", nullsfirst(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", nullsfirst("y"), func.lower(t.c.x)),
+ ),
+ (
+ lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullslast(asc(func.lower(t.c.x)))),
+ lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+ lambda t: Index("SomeIndex", nullsfirst(desc(func.lower(t.c.x)))),
+ ),
+ (
+ lambda t: Index("SomeIndex", nullsfirst(func.lower(t.c.x))),
+ lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+ ),
+ ]
+
+ req = config.requirements.reflects_indexes_column_sorting
+
+ if flatten:
+
+ flat = list(itertools.chain.from_iterable(diff_pairs))
+ for f1, f2 in with_sort:
+ flat.extend([(f1, req), (f2, req)])
+ return flat
+ else:
+ return diff_pairs + [(f1, f2, req) for f1, f2 in with_sort]
+
+
+def _lost_of_equal_indexes(_lots_of_indexes):
+ equal_pairs = [
+ (fn, fn) if not isinstance(fn, tuple) else (fn[0], fn[0], fn[1])
+ for fn in _lots_of_indexes(flatten=True)
+ ]
+ equal_pairs += [
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", asc(func.lower(t.c.x))),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index("SomeIndex", "y", nullslast(func.lower(t.c.x))),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+ lambda t: Index(
+ "SomeIndex", "y", nullslast(asc(func.lower(t.c.x)))
+ ),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ (
+ lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+ lambda t: Index(
+ "SomeIndex", "y", nullsfirst(desc(func.lower(t.c.x)))
+ ),
+ config.requirements.reflects_indexes_column_sorting,
+ ),
+ ]
+ return equal_pairs
+
class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
"""tests involving indexes with expression"""
@@ -1263,74 +1461,6 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
diffs = self._fixture(m1, m2)
eq_(diffs, [])
- def _lots_of_indexes(flatten: bool = False):
- diff_pairs = [
- (
- lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
- lambda t: Index("SomeIndex", func.lower(t.c.x)),
- ),
- (
- lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
- lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
- ),
- (
- lambda t: Index(
- "SomeIndex", "y", func.lower(column("x")), _table=t
- ),
- lambda t: Index(
- "SomeIndex", func.lower(column("x")), _table=t
- ),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
- lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
- lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
- lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
- lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
- ),
- (
- lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
- lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
- ),
- (
- lambda t: Index("SomeIndex", t.c.ff + 42),
- lambda t: Index("SomeIndex", 42 + t.c.ff),
- ),
- ]
- if flatten:
- return list(itertools.chain.from_iterable(diff_pairs))
- else:
- return diff_pairs
-
@testing.fixture
def index_changed_tables(self):
m1 = MetaData()
@@ -1412,12 +1542,16 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
diffs = self._fixture(m1, m2)
eq_(diffs, [])
- @combinations(*_lots_of_indexes(flatten=True), argnames="fn")
- def test_expression_indexes_no_change(self, index_changed_tables, fn):
+ @combinations(
+ *_lost_of_equal_indexes(_lots_of_indexes), argnames="fn1, fn2"
+ )
+ def test_expression_indexes_no_change(
+ self, index_changed_tables, fn1, fn2
+ ):
m1, m2, old_fixture_tables, new_fixture_tables = index_changed_tables
- resolve_lambda(fn, **old_fixture_tables)
- resolve_lambda(fn, **new_fixture_tables)
+ resolve_lambda(fn1, **old_fixture_tables)
+ resolve_lambda(fn2, **new_fixture_tables)
if self.has_reflection:
ctx = nullcontext()