summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDimitris Theodorou <dimitris.theodorou@gmail.com>2015-01-12 00:41:59 +0100
committerDimitris Theodorou <dimitris.theodorou@gmail.com>2015-01-12 00:41:59 +0100
commit214228a9dd0fa307fc9d0ebfd6a52390cd63f6f7 (patch)
tree16d79a9ce78986c6397cebd76749eede59ae4025
parent00dae5f77e51c8cb4902939a5b85efdfb920c68b (diff)
downloadalembic-214228a9dd0fa307fc9d0ebfd6a52390cd63f6f7.tar.gz
Change single-quoting of floats in PostgreSQL compare_server_default
Do not wrap string defaults with single quotes when comparing against columns of type float or numeric. This fixes the crash occuring when the default of a float column is an integer value (e.g., DEFAULT 5), while the Python server_default is a string (e.g., server_default="5.0"). This results in the query used in the comparison to throw a DataError ('SELECT 5 = '5.0').
-rw-r--r--alembic/ddl/postgresql.py7
-rw-r--r--tests/test_postgresql.py70
2 files changed, 72 insertions, 5 deletions
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 0877c95..4c6e9d7 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -4,7 +4,7 @@ from .. import compat
from .base import compiles, alter_table, format_table_name, RenameTable
from .impl import DefaultImpl
from sqlalchemy.dialects.postgresql import INTEGER, BIGINT
-from sqlalchemy import text
+from sqlalchemy import text, Float, Numeric
import logging
log = logging.getLogger(__name__)
@@ -35,7 +35,10 @@ class PostgresqlImpl(DefaultImpl):
if metadata_column.server_default is not None and \
isinstance(metadata_column.server_default.arg,
compat.string_types) and \
- not re.match(r"^'.+'$", rendered_metadata_default):
+ not re.match(r"^'.+'$", rendered_metadata_default) and \
+ not isinstance(inspector_column.type, (Float, Numeric)):
+ # don't single quote if the column type is float/numeric,
+ # otherwise a comparison such as SELECT 5 = '5.0' will fail
rendered_metadata_default = "'%s'" % rendered_metadata_default
return not self.connection.scalar(
diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py
index 908eec6..e70d05a 100644
--- a/tests/test_postgresql.py
+++ b/tests/test_postgresql.py
@@ -1,6 +1,6 @@
from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
- String, Interval, Sequence, Numeric, BigInteger
+ String, Interval, Sequence, Numeric, BigInteger, Float, Numeric
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.engine.reflection import Inspector
from alembic.operations import Operations
@@ -193,8 +193,11 @@ class PostgresqlDefaultCompareTest(TestBase):
def tearDown(self):
self.metadata.drop_all()
- def _compare_default_roundtrip(self, type_, orig_default, alternate=None):
- diff_expected = alternate is not None
+ def _compare_default_roundtrip(
+ self, type_, orig_default, alternate=None, diff_expected=None):
+ diff_expected = diff_expected \
+ if diff_expected is not None \
+ else alternate is not None
if alternate is None:
alternate = orig_default
@@ -274,6 +277,67 @@ class PostgresqlDefaultCompareTest(TestBase):
text("5"), "7"
)
+ def test_compare_float_str(self):
+ self._compare_default_roundtrip(
+ Float(),
+ "5.2",
+ )
+
+ def test_compare_float_text(self):
+ self._compare_default_roundtrip(
+ Float(),
+ text("5.2"),
+ )
+
+ def test_compare_float_no_diff1(self):
+ self._compare_default_roundtrip(
+ Float(),
+ text("5.2"), "5.2",
+ diff_expected=False
+ )
+
+ def test_compare_float_no_diff2(self):
+ self._compare_default_roundtrip(
+ Float(),
+ "5.2", text("5.2"),
+ diff_expected=False
+ )
+
+ def test_compare_float_no_diff3(self):
+ self._compare_default_roundtrip(
+ Float(),
+ text("5"), text("5.0"),
+ diff_expected=False
+ )
+
+ def test_compare_float_no_diff4(self):
+ self._compare_default_roundtrip(
+ Float(),
+ "5", "5.0",
+ diff_expected=False
+ )
+
+ def test_compare_float_no_diff5(self):
+ self._compare_default_roundtrip(
+ Float(),
+ text("5"), "5.0",
+ diff_expected=False
+ )
+
+ def test_compare_float_no_diff6(self):
+ self._compare_default_roundtrip(
+ Float(),
+ "5", text("5.0"),
+ diff_expected=False
+ )
+
+ def test_compare_numeric_no_diff(self):
+ self._compare_default_roundtrip(
+ Numeric(),
+ text("5"), "5.0",
+ diff_expected=False
+ )
+
def test_compare_character_str(self):
self._compare_default_roundtrip(
String(),