diff options
Diffstat (limited to 'test/sql/test_operators.py')
| -rw-r--r-- | test/sql/test_operators.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index a19eb20bc..f6a13f8ca 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -326,6 +326,60 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): return MyInteger + @testing.fixture + def modulus(self): + class MyInteger(Integer): + class comparator_factory(Integer.Comparator): + def modulus(self): + return UnaryExpression( + self.expr, + modifier=operators.custom_op("%"), + type_=MyInteger, + ) + + def modulus_prefix(self): + return UnaryExpression( + self.expr, + operator=operators.custom_op("%"), + type_=MyInteger, + ) + + return MyInteger + + @testing.combinations( + ("format",), + ("qmark",), + ("named",), + ("pyformat",), + argnames="paramstyle", + ) + def test_modulus(self, modulus, paramstyle): + col = column("somecol", modulus()) + self.assert_compile( + col.modulus(), + "somecol %%" + if paramstyle in ("format", "pyformat") + else "somecol %", + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + + @testing.combinations( + ("format",), + ("qmark",), + ("named",), + ("pyformat",), + argnames="paramstyle", + ) + def test_modulus_prefix(self, modulus, paramstyle): + col = column("somecol", modulus()) + self.assert_compile( + col.modulus_prefix(), + "%% somecol" + if paramstyle in ("format", "pyformat") + else "% somecol", + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + def test_factorial(self, factorial): col = column("somecol", factorial()) self.assert_compile(col.factorial(), "somecol !") @@ -1950,6 +2004,21 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): self.assert_compile(py_op(lhs, rhs), res % sql_op) + @testing.combinations( + ("format", "mytable.myid %% %s"), + ("qmark", "mytable.myid % ?"), + ("named", "mytable.myid % :myid_1"), + ("pyformat", "mytable.myid %% %(myid_1)s"), + ) + def test_custom_op_percent_escaping(self, paramstyle, expected): + expr = self.table1.c.myid.op("%")(5) + + self.assert_compile( + expr, + expected, + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" |
