diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2021-12-27 16:29:23 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2021-12-27 16:29:23 +0000 |
| commit | fd99a4aa808f91f87d0a678708dd9c2b131fda04 (patch) | |
| tree | 75ef065f4b6dc1b250467dddd1c713bac51d8f18 /lib/sqlalchemy | |
| parent | 4a12848a1cf47ed43c93c5ee8029b644242d0a17 (diff) | |
| parent | 6d589ffbb5fe04a4ee606819e948974045f62b80 (diff) | |
| download | sqlalchemy-fd99a4aa808f91f87d0a678708dd9c2b131fda04.tar.gz | |
Merge "consider truediv as truediv; support floordiv operator" into main
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/pysqlite.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 36 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/operators.py | 29 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 86 |
10 files changed, 179 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index fe0624d08..50bae40b8 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2327,6 +2327,8 @@ class MySQLDialect(default.DefaultDialect): max_index_name_length = 64 max_constraint_name_length = 64 + div_is_floordiv = False + supports_native_enum = True supports_sequences = False # default for MySQL ... diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 63131bf95..94feeefce 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1441,6 +1441,8 @@ class OracleDialect(default.DefaultDialect): implicit_returning = True + div_is_floordiv = False + supports_simple_order_by_label = False cte_follows_insert = True diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 43883c4b7..d238de1ab 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1215,6 +1215,13 @@ class SQLiteCompiler(compiler.SQLCompiler): }, ) + def visit_truediv_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " / " + + "(%s + 0.0)" % self.process(binary.right, **kw) + ) + def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 944d714a3..77c9ebce7 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -402,6 +402,7 @@ by adding the desired locking mode to our ``"BEGIN"``:: """ # noqa +import math import os import re @@ -505,14 +506,23 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return None return re.search(a, b) is not None + create_func_kw = {"deterministic": True} if util.py38 else {} + def set_regexp(dbapi_connection): dbapi_connection.create_function( - "regexp", - 2, - regexp, + "regexp", 2, regexp, **create_func_kw + ) + + def floor_func(dbapi_connection): + # NOTE: floor is optionally present in sqlite 3.35+ , however + # as it is normally non-present we deliver floor() unconditionally + # for now. + # https://www.sqlite.org/lang_mathfunc.html + dbapi_connection.create_function( + "floor", 1, math.floor, **create_func_kw ) - fns = [set_regexp] + fns = [set_regexp, floor_func] def connect(conn): for fn in fns: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e91e34f00..779939be8 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -55,6 +55,8 @@ class DefaultDialect(interfaces.Dialect): inline_comments = False supports_statement_cache = True + div_is_floordiv = True + bind_typing = interfaces.BindTyping.NONE include_set_input_sizes = None diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 28c1bf069..5f6ee5f41 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -177,7 +177,6 @@ OPERATORS = { operators.mul: " * ", operators.sub: " - ", operators.mod: " % ", - operators.truediv: " / ", operators.neg: "-", operators.lt: " < ", operators.le: " <= ", @@ -1923,6 +1922,41 @@ class SQLCompiler(Compiled): "Unary expression has no operator or modifier" ) + def visit_truediv_binary(self, binary, operator, **kw): + if self.dialect.div_is_floordiv: + return ( + self.process(binary.left, **kw) + + " / " + # TODO: would need a fast cast again here, + # unless we want to use an implicit cast like "+ 0.0" + + self.process( + elements.Cast(binary.right, sqltypes.Numeric()), **kw + ) + ) + else: + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + + def visit_floordiv_binary(self, binary, operator, **kw): + if ( + self.dialect.div_is_floordiv + and binary.right.type._type_affinity is sqltypes.Integer + ): + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + else: + return "FLOOR(%s)" % ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + def visit_is_true_unary_operator(self, element, operator, **kw): if ( element._is_implicitly_boolean diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 036a96e9f..2bbead673 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -310,6 +310,7 @@ operator_lookup = { "div": (_binary_operate,), "mod": (_binary_operate,), "truediv": (_binary_operate,), + "floordiv": (_binary_operate,), "custom_op": (_custom_op_operate,), "json_path_getitem_op": (_binary_operate,), "json_getitem_op": (_binary_operate,), diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 6d45cd033..74eb73e46 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -14,6 +14,7 @@ from operator import add from operator import and_ from operator import contains from operator import eq +from operator import floordiv from operator import ge from operator import getitem from operator import gt @@ -1220,7 +1221,12 @@ class ColumnOperators(Operators): def __truediv__(self, other): """Implement the ``/`` operator. - In a column context, produces the clause ``a / b``. + In a column context, produces the clause ``a / b``, and + considers the result type to be numeric. + + .. versionchanged:: 2.0 The truediv operator against two integers + is now considered to return a numeric value. Behavior on specific + backends may vary. """ return self.operate(truediv, other) @@ -1233,6 +1239,26 @@ class ColumnOperators(Operators): """ return self.reverse_operate(truediv, other) + def __floordiv__(self, other): + """Implement the ``//`` operator. + + In a column context, produces the clause ``a / b``, + which is the same as "truediv", but considers the result + type to be integer. + + .. versionadded:: 2.0 + + """ + return self.operate(floordiv, other) + + def __rfloordiv__(self, other): + """Implement the ``//`` operator in reverse. + + See :meth:`.ColumnOperators.__floordiv__`. + + """ + return self.reverse_operate(floordiv, other) + _commutative = {eq, ne, add, mul} _comparison = {eq, ne, lt, gt, ge, le} @@ -1588,6 +1614,7 @@ _PRECEDENCE = { json_path_getitem_op: 15, mul: 8, truediv: 8, + floordiv: 8, mod: 8, neg: 8, add: 7, diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e65fa3c14..f035284f4 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -310,8 +310,6 @@ class Integer(_LookupExpressionAdapter, TypeEngine): @util.memoized_property def _expression_adaptations(self): - # TODO: need a dictionary object that will - # handle operators generically here, this is incomplete return { operators.add: { Date: Date, @@ -323,7 +321,8 @@ class Integer(_LookupExpressionAdapter, TypeEngine): Integer: self.__class__, Numeric: Numeric, }, - operators.truediv: {Integer: self.__class__, Numeric: Numeric}, + operators.truediv: {Integer: Numeric, Numeric: Numeric}, + operators.floordiv: {Integer: self.__class__, Numeric: Numeric}, operators.sub: {Integer: self.__class__, Numeric: Numeric}, } diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index e7131ec6e..78596457e 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -26,6 +26,7 @@ from ... import Float from ... import Integer from ... import JSON from ... import literal +from ... import literal_column from ... import MetaData from ... import null from ... import Numeric @@ -505,6 +506,90 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): eq_(result, {2}) +class TrueDivTest(fixtures.TestBase): + @testing.combinations( + ("15", "10", 1.5), + ("-15", "10", -1.5), + argnames="left, right, expected", + ) + def test_truediv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + / literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("15", "10", 1), ("-15", "5", -3), argnames="left, right, expected" + ) + def test_floordiv_integer(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Integer()) + // literal_column(right, type_=Integer()) + ) + ), + expected, + ) + + @testing.combinations( + ("5.52", "2.4", "2.3"), argnames="left, right, expected" + ) + def test_truediv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric()) + / literal_column(right, type_=Numeric()) + ) + ), + decimal.Decimal(expected), + ) + + @testing.combinations( + ("5.52", "2.4", "2.0"), argnames="left, right, expected" + ) + def test_floordiv_numeric(self, connection, left, right, expected): + """test #4926""" + + eq_( + connection.scalar( + select( + literal_column(left, type_=Numeric()) + // literal_column(right, type_=Numeric()) + ) + ), + decimal.Decimal(expected), + ) + + def test_truediv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) / literal(10))), + 1.5, + ) + + def test_floordiv_integer_bound(self, connection): + """test #4926""" + + eq_( + connection.scalar(select(literal(15) // literal(10))), + 1, + ) + + class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @@ -1439,6 +1524,7 @@ __all__ = ( "TimeMicrosecondsTest", "TimestampMicrosecondsTest", "TimeTest", + "TrueDivTest", "DateTimeMicrosecondsTest", "DateHistoricTest", "StringTest", |
