diff options
| -rw-r--r-- | doc/build/changelog/unreleased_13/5808.rst | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 18 | ||||
| -rw-r--r-- | test/dialect/mysql/test_compiler.py | 32 |
3 files changed, 52 insertions, 4 deletions
diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst new file mode 100644 index 000000000..b6625c050 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5808.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, mysql + :tickets: 5808 + + Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and + MariaDb >= (10, 4, 5).
\ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7a4d3261f..a4b583541 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1624,6 +1624,11 @@ class MySQLCompiler(compiler.SQLCompiler): return self.dialect.type_compiler.process(type_).replace( "NUMERIC", "DECIMAL" ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler.process(type_) else: return None @@ -1631,7 +1636,7 @@ class MySQLCompiler(compiler.SQLCompiler): type_ = self.process(cast.typeclause) if type_ is None: util.warn( - "Datatype %s does not support CAST on MySQL; " + "Datatype %s does not support CAST on MySQL/MariaDb; " "the CAST will be skipped." % self.dialect.type_compiler.process(cast.typeclause.type) ) @@ -2900,6 +2905,17 @@ class MySQLDialect(default.DefaultDialect): ) @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + + @property def _is_mariadb(self): return self.is_mariadb diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 2993f96b8..62292b9da 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -710,7 +710,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_cast_literal_bind(self): expr = cast(column("foo", Integer) + 5, Float) - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(expr, "(foo + 5)", literal_binds=True) m = mysql @@ -734,11 +736,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_casts(self, type_, expected): t = sql.table("t", sql.column("col")) - with expect_warnings("Datatype .* does not support CAST on MySQL;"): + with expect_warnings( + "Datatype .* does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(cast(t.c.col, type_), expected) + @testing.combinations( + (m.FLOAT, "CAST(t.col AS FLOAT)"), + (Float, "CAST(t.col AS FLOAT)"), + (FLOAT, "CAST(t.col AS FLOAT)"), + (m.DOUBLE, "CAST(t.col AS DOUBLE)"), + (m.FLOAT, "CAST(t.col AS FLOAT)"), + argnames="type_,expected", + ) + @testing.combinations(True, False, argnames="maria_db") + def test_float_cast(self, type_, expected, maria_db): + + dialect = mysql.dialect() + if maria_db: + dialect.is_mariadb = maria_db + dialect.server_version_info = (10, 4, 5) + else: + dialect.server_version_info = (8, 0, 17) + t = sql.table("t", sql.column("col")) + self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect) + def test_cast_grouped_expression_non_castable(self): - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile( cast(sql.column("x") + sql.column("y"), Float), "(x + y)" ) |
