summaryrefslogtreecommitdiff
path: root/test/sql/test_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql/test_operators.py')
-rw-r--r--test/sql/test_operators.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py
index a19eb20bc..f6a13f8ca 100644
--- a/test/sql/test_operators.py
+++ b/test/sql/test_operators.py
@@ -326,6 +326,60 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
return MyInteger
+ @testing.fixture
+ def modulus(self):
+ class MyInteger(Integer):
+ class comparator_factory(Integer.Comparator):
+ def modulus(self):
+ return UnaryExpression(
+ self.expr,
+ modifier=operators.custom_op("%"),
+ type_=MyInteger,
+ )
+
+ def modulus_prefix(self):
+ return UnaryExpression(
+ self.expr,
+ operator=operators.custom_op("%"),
+ type_=MyInteger,
+ )
+
+ return MyInteger
+
+ @testing.combinations(
+ ("format",),
+ ("qmark",),
+ ("named",),
+ ("pyformat",),
+ argnames="paramstyle",
+ )
+ def test_modulus(self, modulus, paramstyle):
+ col = column("somecol", modulus())
+ self.assert_compile(
+ col.modulus(),
+ "somecol %%"
+ if paramstyle in ("format", "pyformat")
+ else "somecol %",
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
+ @testing.combinations(
+ ("format",),
+ ("qmark",),
+ ("named",),
+ ("pyformat",),
+ argnames="paramstyle",
+ )
+ def test_modulus_prefix(self, modulus, paramstyle):
+ col = column("somecol", modulus())
+ self.assert_compile(
+ col.modulus_prefix(),
+ "%% somecol"
+ if paramstyle in ("format", "pyformat")
+ else "% somecol",
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
def test_factorial(self, factorial):
col = column("somecol", factorial())
self.assert_compile(col.factorial(), "somecol !")
@@ -1950,6 +2004,21 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
):
self.assert_compile(py_op(lhs, rhs), res % sql_op)
+ @testing.combinations(
+ ("format", "mytable.myid %% %s"),
+ ("qmark", "mytable.myid % ?"),
+ ("named", "mytable.myid % :myid_1"),
+ ("pyformat", "mytable.myid %% %(myid_1)s"),
+ )
+ def test_custom_op_percent_escaping(self, paramstyle, expected):
+ expr = self.table1.c.myid.op("%")(5)
+
+ self.assert_compile(
+ expr,
+ expected,
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"