summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-12-27 16:29:23 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-12-27 16:29:23 +0000
commitfd99a4aa808f91f87d0a678708dd9c2b131fda04 (patch)
tree75ef065f4b6dc1b250467dddd1c713bac51d8f18 /lib/sqlalchemy
parent4a12848a1cf47ed43c93c5ee8029b644242d0a17 (diff)
parent6d589ffbb5fe04a4ee606819e948974045f62b80 (diff)
downloadsqlalchemy-fd99a4aa808f91f87d0a678708dd9c2b131fda04.tar.gz
Merge "consider truediv as truediv; support floordiv operator" into main
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py2
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py7
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlite.py18
-rw-r--r--lib/sqlalchemy/engine/default.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py36
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py1
-rw-r--r--lib/sqlalchemy/sql/operators.py29
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py5
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py86
10 files changed, 179 insertions, 9 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index fe0624d08..50bae40b8 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -2327,6 +2327,8 @@ class MySQLDialect(default.DefaultDialect):
max_index_name_length = 64
max_constraint_name_length = 64
+ div_is_floordiv = False
+
supports_native_enum = True
supports_sequences = False # default for MySQL ...
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 63131bf95..94feeefce 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1441,6 +1441,8 @@ class OracleDialect(default.DefaultDialect):
implicit_returning = True
+ div_is_floordiv = False
+
supports_simple_order_by_label = False
cte_follows_insert = True
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 43883c4b7..d238de1ab 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1215,6 +1215,13 @@ class SQLiteCompiler(compiler.SQLCompiler):
},
)
+ def visit_truediv_binary(self, binary, operator, **kw):
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + "(%s + 0.0)" % self.process(binary.right, **kw)
+ )
+
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
index 944d714a3..77c9ebce7 100644
--- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -402,6 +402,7 @@ by adding the desired locking mode to our ``"BEGIN"``::
""" # noqa
+import math
import os
import re
@@ -505,14 +506,23 @@ class SQLiteDialect_pysqlite(SQLiteDialect):
return None
return re.search(a, b) is not None
+ create_func_kw = {"deterministic": True} if util.py38 else {}
+
def set_regexp(dbapi_connection):
dbapi_connection.create_function(
- "regexp",
- 2,
- regexp,
+ "regexp", 2, regexp, **create_func_kw
+ )
+
+ def floor_func(dbapi_connection):
+ # NOTE: floor is optionally present in sqlite 3.35+ , however
+ # as it is normally non-present we deliver floor() unconditionally
+ # for now.
+ # https://www.sqlite.org/lang_mathfunc.html
+ dbapi_connection.create_function(
+ "floor", 1, math.floor, **create_func_kw
)
- fns = [set_regexp]
+ fns = [set_regexp, floor_func]
def connect(conn):
for fn in fns:
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index e91e34f00..779939be8 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -55,6 +55,8 @@ class DefaultDialect(interfaces.Dialect):
inline_comments = False
supports_statement_cache = True
+ div_is_floordiv = True
+
bind_typing = interfaces.BindTyping.NONE
include_set_input_sizes = None
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 28c1bf069..5f6ee5f41 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -177,7 +177,6 @@ OPERATORS = {
operators.mul: " * ",
operators.sub: " - ",
operators.mod: " % ",
- operators.truediv: " / ",
operators.neg: "-",
operators.lt: " < ",
operators.le: " <= ",
@@ -1923,6 +1922,41 @@ class SQLCompiler(Compiled):
"Unary expression has no operator or modifier"
)
+ def visit_truediv_binary(self, binary, operator, **kw):
+ if self.dialect.div_is_floordiv:
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ # TODO: would need a fast cast again here,
+ # unless we want to use an implicit cast like "+ 0.0"
+ + self.process(
+ elements.Cast(binary.right, sqltypes.Numeric()), **kw
+ )
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+
+ def visit_floordiv_binary(self, binary, operator, **kw):
+ if (
+ self.dialect.div_is_floordiv
+ and binary.right.type._type_affinity is sqltypes.Integer
+ ):
+ return (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return "FLOOR(%s)" % (
+ self.process(binary.left, **kw)
+ + " / "
+ + self.process(binary.right, **kw)
+ )
+
def visit_is_true_unary_operator(self, element, operator, **kw):
if (
element._is_implicitly_boolean
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 036a96e9f..2bbead673 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -310,6 +310,7 @@ operator_lookup = {
"div": (_binary_operate,),
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
+ "floordiv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
"json_path_getitem_op": (_binary_operate,),
"json_getitem_op": (_binary_operate,),
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 6d45cd033..74eb73e46 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -14,6 +14,7 @@ from operator import add
from operator import and_
from operator import contains
from operator import eq
+from operator import floordiv
from operator import ge
from operator import getitem
from operator import gt
@@ -1220,7 +1221,12 @@ class ColumnOperators(Operators):
def __truediv__(self, other):
"""Implement the ``/`` operator.
- In a column context, produces the clause ``a / b``.
+ In a column context, produces the clause ``a / b``, and
+ considers the result type to be numeric.
+
+ .. versionchanged:: 2.0 The truediv operator against two integers
+ is now considered to return a numeric value. Behavior on specific
+ backends may vary.
"""
return self.operate(truediv, other)
@@ -1233,6 +1239,26 @@ class ColumnOperators(Operators):
"""
return self.reverse_operate(truediv, other)
+ def __floordiv__(self, other):
+ """Implement the ``//`` operator.
+
+ In a column context, produces the clause ``a / b``,
+ which is the same as "truediv", but considers the result
+ type to be integer.
+
+ .. versionadded:: 2.0
+
+ """
+ return self.operate(floordiv, other)
+
+ def __rfloordiv__(self, other):
+ """Implement the ``//`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__floordiv__`.
+
+ """
+ return self.reverse_operate(floordiv, other)
+
_commutative = {eq, ne, add, mul}
_comparison = {eq, ne, lt, gt, ge, le}
@@ -1588,6 +1614,7 @@ _PRECEDENCE = {
json_path_getitem_op: 15,
mul: 8,
truediv: 8,
+ floordiv: 8,
mod: 8,
neg: 8,
add: 7,
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index e65fa3c14..f035284f4 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -310,8 +310,6 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
@util.memoized_property
def _expression_adaptations(self):
- # TODO: need a dictionary object that will
- # handle operators generically here, this is incomplete
return {
operators.add: {
Date: Date,
@@ -323,7 +321,8 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
Integer: self.__class__,
Numeric: Numeric,
},
- operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: Numeric, Numeric: Numeric},
+ operators.floordiv: {Integer: self.__class__, Numeric: Numeric},
operators.sub: {Integer: self.__class__, Numeric: Numeric},
}
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index e7131ec6e..78596457e 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -26,6 +26,7 @@ from ... import Float
from ... import Integer
from ... import JSON
from ... import literal
+from ... import literal_column
from ... import MetaData
from ... import null
from ... import Numeric
@@ -505,6 +506,90 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
eq_(result, {2})
+class TrueDivTest(fixtures.TestBase):
+ @testing.combinations(
+ ("15", "10", 1.5),
+ ("-15", "10", -1.5),
+ argnames="left, right, expected",
+ )
+ def test_truediv_integer(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Integer())
+ / literal_column(right, type_=Integer())
+ )
+ ),
+ expected,
+ )
+
+ @testing.combinations(
+ ("15", "10", 1), ("-15", "5", -3), argnames="left, right, expected"
+ )
+ def test_floordiv_integer(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Integer())
+ // literal_column(right, type_=Integer())
+ )
+ ),
+ expected,
+ )
+
+ @testing.combinations(
+ ("5.52", "2.4", "2.3"), argnames="left, right, expected"
+ )
+ def test_truediv_numeric(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Numeric())
+ / literal_column(right, type_=Numeric())
+ )
+ ),
+ decimal.Decimal(expected),
+ )
+
+ @testing.combinations(
+ ("5.52", "2.4", "2.0"), argnames="left, right, expected"
+ )
+ def test_floordiv_numeric(self, connection, left, right, expected):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(
+ select(
+ literal_column(left, type_=Numeric())
+ // literal_column(right, type_=Numeric())
+ )
+ ),
+ decimal.Decimal(expected),
+ )
+
+ def test_truediv_integer_bound(self, connection):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(select(literal(15) / literal(10))),
+ 1.5,
+ )
+
+ def test_floordiv_integer_bound(self, connection):
+ """test #4926"""
+
+ eq_(
+ connection.scalar(select(literal(15) // literal(10))),
+ 1,
+ )
+
+
class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
__backend__ = True
@@ -1439,6 +1524,7 @@ __all__ = (
"TimeMicrosecondsTest",
"TimestampMicrosecondsTest",
"TimeTest",
+ "TrueDivTest",
"DateTimeMicrosecondsTest",
"DateHistoricTest",
"StringTest",