summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-12-08 08:57:44 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-12-26 19:32:53 -0500
commit6d589ffbb5fe04a4ee606819e948974045f62b80 (patch)
tree95fc3ac54ae23945e3bf810f85294193f4fbbd82 /lib
parent2bb6cfc7c9b8f09eaa4efeffc337a1162993979c (diff)
downloadsqlalchemy-6d589ffbb5fe04a4ee606819e948974045f62b80.tar.gz
consider truediv as truediv; support floordiv operator
Implemented full support for "truediv" and "floordiv" using the "/" and "//" operators. A "truediv" operation between two expressions using :class:`_types.Integer` now considers the result to be :class:`_types.Numeric`, and the dialect-level compilation will cast the right operand to a numeric type on a dialect-specific basis to ensure truediv is achieved. For floordiv, conversion is also added for those databases that don't already do floordiv by default (MySQL, Oracle) and the ``FLOOR()`` function is rendered in this case, as well as for cases where the right operand is not an integer (needed for PostgreSQL, others). The change resolves issues both with inconsistent behavior of the division operator on different backends and also fixes an issue where integer division on Oracle would fail to be able to fetch a result due to inappropriate outputtypehandlers. Fixes: #4926 Change-Id: Id54cc018c1fb7a49dd3ce1216d68d40f43fe2659
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py2
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py7
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlite.py18
-rw-r--r--lib/sqlalchemy/engine/default.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py36
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py1
-rw-r--r--lib/sqlalchemy/sql/operators.py29
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py86
10 files changed, 179 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index fe0624d08..50bae40b8 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -2327,6 +2327,8 @@ class MySQLDialect(default.DefaultDialect):
max_index_name_length = 64
max_constraint_name_length = 64
+ div_is_floordiv = False
+
supports_native_enum = True
supports_sequences = False # default for MySQL ...
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 63131bf95..94feeefce 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1441,6 +1441,8 @@ class OracleDialect(default.DefaultDialect):
implicit_returning = True
+ div_is_floordiv = False
+
supports_simple_order_by_label = False
cte_follows_insert = True
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 43883c4b7..d238de1ab 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1215,6 +1215,13 @@ class SQLiteCompiler(compiler.SQLCompiler):
},
)
+ def visit_truediv_binary(self, binary, operator, **kw):
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + "(%s + 0.0)" % self.process(binary.right, **kw)
+ )
+
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 944d714a3..77c9ebce7 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -402,6 +402,7 @@ by adding the desired locking mode to our ``"BEGIN"``::
""" # noqa
+import math
import os
import re
@@ -505,14 +506,23 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return None
return re.search(a, b) is not None
+ create_func_kw = {"deterministic": True} if util.py38 else {}
+
def set_regexp(dbapi_connection):
dbapi_connection.create_function(
- "regexp",
- 2,
- regexp,
+ "regexp", 2, regexp, **create_func_kw
+ )
+
+ def floor_func(dbapi_connection):
+ # NOTE: floor is optionally present in sqlite 3.35+ , however
+ # as it is normally non-present we deliver floor() unconditionally
+ # for now.
+ # https://www.sqlite.org/lang_mathfunc.html
+ dbapi_connection.create_function(
+ "floor", 1, math.floor, **create_func_kw
)
- fns = [set_regexp]
+ fns = [set_regexp, floor_func]
def connect(conn):
for fn in fns:
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index e91e34f00..779939be8 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -55,6 +55,8 @@ class DefaultDialect(interfaces.Dialect):
inline_comments = False
supports_statement_cache = True
+ div_is_floordiv = True
+
bind_typing = interfaces.BindTyping.NONE
include_set_input_sizes = None
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 28c1bf069..5f6ee5f41 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -177,7 +177,6 @@ OPERATORS = {
operators.mul: " * ",
operators.sub: " - ",
operators.mod: " % ",
- operators.truediv: " / ",
operators.neg: "-",
operators.lt: " < ",
operators.le: " <= ",
@@ -1923,6 +1922,41 @@ class SQLCompiler(Compiled):
"Unary expression has no operator or modifier"
)
+ def visit_truediv_binary(self, binary, operator, **kw):
+ if self.dialect.div_is_floordiv:
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ # TODO: would need a fast cast again here,
+ # unless we want to use an implicit cast like "+ 0.0"
+ + self.process(
+ elements.Cast(binary.right, sqltypes.Numeric()), **kw
+ )
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+
+ def visit_floordiv_binary(self, binary, operator, **kw):
+ if (
+ self.dialect.div_is_floordiv
+ and binary.right.type._type_affinity is sqltypes.Integer
+ ):
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return "FLOOR(%s)" % (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+
def visit_is_true_unary_operator(self, element, operator, **kw):
if (
element._is_implicitly_boolean
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 036a96e9f..2bbead673 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -310,6 +310,7 @@ operator_lookup = {
"div": (_binary_operate,),
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
+ "floordiv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
"json_path_getitem_op": (_binary_operate,),
"json_getitem_op": (_binary_operate,),
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 6d45cd033..74eb73e46 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -14,6 +14,7 @@ from operator import add
from operator import and_
from operator import contains
from operator import eq
+from operator import floordiv
from operator import ge
from operator import getitem
from operator import gt
@@ -1220,7 +1221,12 @@ class ColumnOperators(Operators):
def __truediv__(self, other):
"""Implement the ``/`` operator.
- In a column context, produces the clause ``a / b``.
+ In a column context, produces the clause ``a / b``, and
+ considers the result type to be numeric.
+
+ .. versionchanged:: 2.0 The truediv operator against two integers
+ is now considered to return a numeric value. Behavior on specific
+ backends may vary.
"""
return self.operate(truediv, other)
@@ -1233,6 +1239,26 @@ class ColumnOperators(Operators):
"""
return self.reverse_operate(truediv, other)
+ def __floordiv__(self, other):
+ """Implement the ``//`` operator.
+
+ In a column context, produces the clause ``a / b``,
+ which is the same as "truediv", but considers the result
+ type to be integer.
+
+ .. versionadded:: 2.0
+
+ """
+ return self.operate(floordiv, other)
+
+ def __rfloordiv__(self, other):
+ """Implement the ``//`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__floordiv__`.
+
+ """
+ return self.reverse_operate(floordiv, other)
+
_commutative = {eq, ne, add, mul}
_comparison = {eq, ne, lt, gt, ge, le}
@@ -1588,6 +1614,7 @@ _PRECEDENCE = {
json_path_getitem_op: 15,
mul: 8,
truediv: 8,
+ floordiv: 8,
mod: 8,
neg: 8,
add: 7,
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index e65fa3c14..f035284f4 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -310,8 +310,6 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
@util.memoized_property
def _expression_adaptations(self):
- # TODO: need a dictionary object that will
- # handle operators generically here, this is incomplete
return {
operators.add: {
Date: Date,
@@ -323,7 +321,8 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
Integer: self.__class__,
Numeric: Numeric,
},
- operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: Numeric, Numeric: Numeric},
+ operators.floordiv: {Integer: self.__class__, Numeric: Numeric},
operators.sub: {Integer: self.__class__, Numeric: Numeric},
}
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index e7131ec6e..78596457e 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -26,6 +26,7 @@ from ... import Float
from ... import Integer
from ... import JSON
from ... import literal
+from ... import literal_column
from ... import MetaData
from ... import null
from ... import Numeric
@@ -505,6 +506,90 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
eq_(result, {2})
+class TrueDivTest(fixtures.TestBase):
+ @testing.combinations(
+ ("15", "10", 1.5),
+ ("-15", "10", -1.5),
+ argnames="left, right, expected",
+ )
+ def test_truediv_integer(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Integer())
+ / literal_column(right, type_=Integer())
+ )
+ ),
+ expected,
+ )
+
+ @testing.combinations(
+ ("15", "10", 1), ("-15", "5", -3), argnames="left, right, expected"
+ )
+ def test_floordiv_integer(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Integer())
+ // literal_column(right, type_=Integer())
+ )
+ ),
+ expected,
+ )
+
+ @testing.combinations(
+ ("5.52", "2.4", "2.3"), argnames="left, right, expected"
+ )
+ def test_truediv_numeric(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Numeric())
+ / literal_column(right, type_=Numeric())
+ )
+ ),
+ decimal.Decimal(expected),
+ )
+
+ @testing.combinations(
+ ("5.52", "2.4", "2.0"), argnames="left, right, expected"
+ )
+ def test_floordiv_numeric(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Numeric())
+ // literal_column(right, type_=Numeric())
+ )
+ ),
+ decimal.Decimal(expected),
+ )
+
+ def test_truediv_integer_bound(self, connection):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(select(literal(15) / literal(10))),
+ 1.5,
+ )
+
+ def test_floordiv_integer_bound(self, connection):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(select(literal(15) // literal(10))),
+ 1,
+ )
+
+
class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
@@ -1439,6 +1524,7 @@ __all__ = (
"TimeMicrosecondsTest",
"TimestampMicrosecondsTest",
"TimeTest",
+ "TrueDivTest",
"DateTimeMicrosecondsTest",
"DateHistoricTest",
"StringTest",