summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2023-05-11 21:49:14 +0200
committerFederico Caselli <cfederico87@gmail.com>2023-05-11 22:41:09 +0200
commit230a2932f646800b006c00b434be95c164598525 (patch)
tree03b4b3b665b41461798b54c3a0b29ff9b7cdec75
parent92e54a0e1c96cecd99397cb1aee9c3bb28f780c6 (diff)
downloadalembic-230a2932f646800b006c00b434be95c164598525.tar.gz
Removed server default quoting from compare
Don't modify the metadata server default when comparing it in the autogenerate process. This impacts the value passes to user provided functions passed in :paramref:`.EnvironmentContext.configure.compare_server_default` and third party dialect that implement a custom ``compare_server_default``. Fixes: #1178 Change-Id: Ib429efcf9077337f768ad5aad91659867e89391a
-rw-r--r--alembic/autogenerate/compare.py12
-rw-r--r--alembic/context.pyi17
-rw-r--r--alembic/ddl/mysql.py11
-rw-r--r--alembic/runtime/environment.py16
-rw-r--r--docs/build/unreleased/1178.rst9
-rw-r--r--tests/test_postgresql.py4
6 files changed, 55 insertions, 14 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index 595631c..b489328 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1023,9 +1023,7 @@ def _compare_type(
def _render_server_default_for_compare(
- metadata_default: Optional[Any],
- metadata_col: Column,
- autogen_context: AutogenContext,
+ metadata_default: Optional[Any], autogen_context: AutogenContext
) -> Optional[str]:
if isinstance(metadata_default, sa_schema.DefaultClause):
@@ -1039,11 +1037,7 @@ def _render_server_default_for_compare(
)
)
if isinstance(metadata_default, str):
- if metadata_col.type._type_affinity is sqltypes.String:
- metadata_default = re.sub(r"^'|'$", "", metadata_default)
- return f"'{metadata_default}'"
- else:
- return metadata_default
+ return metadata_default
else:
return None
@@ -1190,7 +1184,7 @@ def _compare_server_default(
)
else:
rendered_metadata_default = _render_server_default_for_compare(
- metadata_default, metadata_col, autogen_context
+ metadata_default, autogen_context
)
rendered_conn_default = (
diff --git a/alembic/context.pyi b/alembic/context.pyi
index c81a14f..621599d 100644
--- a/alembic/context.pyi
+++ b/alembic/context.pyi
@@ -22,6 +22,8 @@ if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
@@ -144,7 +146,20 @@ def configure(
]
] = None,
compare_type: bool = False,
- compare_server_default: bool = False,
+ compare_server_default: Union[
+ bool,
+ Callable[
+ [
+ MigrationContext,
+ Column,
+ Column,
+ Optional[str],
+ Optional[FetchedValue],
+ Optional[str],
+ ],
+ Optional[bool],
+ ],
+ ] = False,
render_item: Optional[
Callable[[str, Any, AutogenContext], Union[str, Literal[False]]]
] = None,
diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py
index a452760..5e66f53 100644
--- a/alembic/ddl/mysql.py
+++ b/alembic/ddl/mysql.py
@@ -185,13 +185,22 @@ class MySQLImpl(DefaultImpl):
and rendered_inspector_default == "'0'"
):
return False
- elif inspector_column.type._type_affinity is sqltypes.Integer:
+ elif (
+ rendered_inspector_default
+ and inspector_column.type._type_affinity is sqltypes.Integer
+ ):
rendered_inspector_default = (
re.sub(r"^'|'$", "", rendered_inspector_default)
if rendered_inspector_default is not None
else None
)
return rendered_inspector_default != rendered_metadata_default
+ elif (
+ rendered_metadata_default
+ and metadata_column.type._type_affinity is sqltypes.String
+ ):
+ metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
+ return rendered_inspector_default != f"'{metadata_default}'"
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION" as can occur particularly
# for the CURRENT_TIMESTAMP function on newer MariaDB versions
diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py
index 71a5309..3087377 100644
--- a/alembic/runtime/environment.py
+++ b/alembic/runtime/environment.py
@@ -15,6 +15,8 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.sql.schema import FetchedValue
from typing_extensions import Literal
from .migration import _ProxyTransaction
@@ -79,6 +81,18 @@ OnVersionApplyFn = Callable[
None,
]
+CompareServerDefault = Callable[
+ [
+ MigrationContext,
+ Column,
+ Column,
+ Optional[str],
+ Optional[FetchedValue],
+ Optional[str],
+ ],
+ Optional[bool],
+]
+
class EnvironmentContext(util.ModuleClsProxy):
@@ -398,7 +412,7 @@ class EnvironmentContext(util.ModuleClsProxy):
ProcessRevisionDirectiveFn
] = None,
compare_type: bool = False,
- compare_server_default: bool = False,
+ compare_server_default: Union[bool, CompareServerDefault] = False,
render_item: Optional[RenderItemFn] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
diff --git a/docs/build/unreleased/1178.rst b/docs/build/unreleased/1178.rst
new file mode 100644
index 0000000..25789d3
--- /dev/null
+++ b/docs/build/unreleased/1178.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: changed, autogenerate
+ :tickets: 1178
+
+ Don't modify the metadata server default when comparing it in the
+ autogenerate process.
+ This impacts the value passes to user provided functions passed in
+ :paramref:`.EnvironmentContext.configure.compare_server_default`
+ and third party dialect that implement a custom ``compare_server_default``.
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 77ed4da..818dae7 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -846,7 +846,7 @@ class PostgresqlDetectSerialTest(TestBase):
eq_(
_render_server_default_for_compare(
- tab.c.x.server_default, tab.c.x, self.autogen_context
+ tab.c.x.server_default, self.autogen_context
),
c_expected,
)
@@ -867,7 +867,7 @@ class PostgresqlDetectSerialTest(TestBase):
server_default = diffs[0][0][4]["existing_server_default"]
eq_(
_render_server_default_for_compare(
- server_default, tab.c.x, self.autogen_context
+ server_default, self.autogen_context
),
c_expected,
)