summaryrefslogtreecommitdiff
path: root/test/sql/test_compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql/test_compiler.py')
-rw-r--r--test/sql/test_compiler.py118
1 files changed, 118 insertions, 0 deletions
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index c71cfd61f..205ce5157 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -4880,6 +4880,124 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
stmt, expected, literal_binds=True, params=params
)
+ standalone_escape = testing.combinations(
+ ("normalname", "normalname"),
+ ("_name", "_name"),
+ ("[BracketsAndCase]", "_BracketsAndCase_"),
+ ("has spaces", "has_spaces"),
+ argnames="paramname, expected",
+ )
+
+ @standalone_escape
+ @testing.variation("use_positional", [True, False])
+ def test_standalone_bindparam_escape(
+ self, paramname, expected, use_positional
+ ):
+ stmt = select(table1.c.myid).where(
+ table1.c.name == bindparam(paramname, value="x")
+ )
+
+ if use_positional:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = ?",
+ params={paramname: "y"},
+ checkpositional=("y",),
+ dialect="sqlite",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = :%s"
+ % (expected,),
+ params={paramname: "y"},
+ checkparams={expected: "y"},
+ dialect="default",
+ )
+
+ @standalone_escape
+ @testing.variation("use_assert_compile", [True, False])
+ @testing.variation("use_positional", [True, False])
+ def test_standalone_bindparam_escape_expanding(
+ self, paramname, expected, use_assert_compile, use_positional
+ ):
+ stmt = select(table1.c.myid).where(
+ table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
+ )
+
+ if use_assert_compile:
+ if use_positional:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable "
+ "WHERE mytable.name IN (?, ?)",
+ params={paramname: ["y", "z"]},
+ # NOTE: this is what render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+ checkpositional=(["y", "z"], ["y", "z"]),
+ dialect="sqlite",
+ render_postcompile=True,
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
+ "(:%s_1, :%s_2)" % (expected, expected),
+ params={paramname: ["y", "z"]},
+ # NOTE: this is what render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+ checkparams={
+ "%s_1" % expected: ["y", "z"],
+ "%s_2" % expected: ["y", "z"],
+ },
+ dialect="default",
+ render_postcompile=True,
+ )
+ else:
+ # this is what DefaultDialect actually does.
+ # this should be matched to DefaultDialect._init_compiled()
+ if use_positional:
+ compiled = stmt.compile(
+ dialect=default.DefaultDialect(paramstyle="qmark")
+ )
+ else:
+ compiled = stmt.compile(dialect=default.DefaultDialect())
+
+ checkparams = compiled.construct_params(
+ {paramname: ["y", "z"]}, escape_names=False
+ )
+
+ # nothing actually happened. if the compiler had
+ # render_postcompile set, the
+ # above weird param thing happens
+ eq_(checkparams, {paramname: ["y", "z"]})
+
+ expanded_state = compiled._process_parameters_for_postcompile(
+ checkparams
+ )
+ eq_(
+ expanded_state.additional_parameters,
+ {f"{expected}_1": "y", f"{expected}_2": "z"},
+ )
+
+ if use_positional:
+ eq_(
+ expanded_state.positiontup,
+ [f"{expected}_1", f"{expected}_2"],
+ )
+
class UnsupportedTest(fixtures.TestBase):
def test_unsupported_element_str_visit_name(self):