summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/5808.rst6
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py18
-rw-r--r--test/dialect/mysql/test_compiler.py32
3 files changed, 52 insertions, 4 deletions
diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst
new file mode 100644
index 000000000..b6625c050
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/5808.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: usecase, mysql
+ :tickets: 5808
+
+ Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and
+ MariaDb >= (10, 4, 5). \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7a4d3261f..a4b583541 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1624,6 +1624,11 @@ class MySQLCompiler(compiler.SQLCompiler):
return self.dialect.type_compiler.process(type_).replace(
"NUMERIC", "DECIMAL"
)
+ elif (
+ isinstance(type_, sqltypes.Float)
+ and self.dialect._support_float_cast
+ ):
+ return self.dialect.type_compiler.process(type_)
else:
return None
@@ -1631,7 +1636,7 @@ class MySQLCompiler(compiler.SQLCompiler):
type_ = self.process(cast.typeclause)
if type_ is None:
util.warn(
- "Datatype %s does not support CAST on MySQL; "
+ "Datatype %s does not support CAST on MySQL/MariaDb; "
"the CAST will be skipped."
% self.dialect.type_compiler.process(cast.typeclause.type)
)
@@ -2900,6 +2905,17 @@ class MySQLDialect(default.DefaultDialect):
)
@property
+ def _support_float_cast(self):
+ if not self.server_version_info:
+ return False
+ elif self.is_mariadb:
+ # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+ return self.server_version_info >= (10, 4, 5)
+ else:
+ # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
+ return self.server_version_info >= (8, 0, 17)
+
+ @property
def _is_mariadb(self):
return self.is_mariadb
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 2993f96b8..62292b9da 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -710,7 +710,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
def test_unsupported_cast_literal_bind(self):
expr = cast(column("foo", Integer) + 5, Float)
- with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(expr, "(foo + 5)", literal_binds=True)
m = mysql
@@ -734,11 +736,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
def test_unsupported_casts(self, type_, expected):
t = sql.table("t", sql.column("col"))
- with expect_warnings("Datatype .* does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype .* does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(cast(t.c.col, type_), expected)
+ @testing.combinations(
+ (m.FLOAT, "CAST(t.col AS FLOAT)"),
+ (Float, "CAST(t.col AS FLOAT)"),
+ (FLOAT, "CAST(t.col AS FLOAT)"),
+ (m.DOUBLE, "CAST(t.col AS DOUBLE)"),
+ (m.FLOAT, "CAST(t.col AS FLOAT)"),
+ argnames="type_,expected",
+ )
+ @testing.combinations(True, False, argnames="maria_db")
+ def test_float_cast(self, type_, expected, maria_db):
+
+ dialect = mysql.dialect()
+ if maria_db:
+ dialect.is_mariadb = maria_db
+ dialect.server_version_info = (10, 4, 5)
+ else:
+ dialect.server_version_info = (8, 0, 17)
+ t = sql.table("t", sql.column("col"))
+ self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
+
def test_cast_grouped_expression_non_castable(self):
- with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(
cast(sql.column("x") + sql.column("y"), Float), "(x + y)"
)