summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-12-15 10:22:36 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-12-15 10:34:05 -0500
commit84ba8874e146bcdbf46ce70ece32c4c224c3fd44 (patch)
tree668b79268cc172ad4292a1875b3d9d14c8eaf47e
parent6f6dc443663f6b63a8fe48b3504cae59cfbe9d56 (diff)
downloadsqlalchemy-84ba8874e146bcdbf46ce70ece32c4c224c3fd44.tar.gz
implement literal_binds with expanding + bind_expression
Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType error in 1.4) when using an expression whose type included :meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding" (i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler parameter. Fixes: #8989 Change-Id: Ic9fd27b46381b488117295ea5a492d8fc158e39f (cherry picked from commit 8c6de3c2c43ab372cbbe76464b4c5be3b6457252)
-rw-r--r--doc/build/changelog/unreleased_14/8989.rst10
-rw-r--r--lib/sqlalchemy/sql/compiler.py69
-rw-r--r--test/sql/test_type_expressions.py48
3 files changed, 100 insertions, 27 deletions
diff --git a/doc/build/changelog/unreleased_14/8989.rst b/doc/build/changelog/unreleased_14/8989.rst
new file mode 100644
index 000000000..4c38fdf01
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/8989.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, types
+ :tickets: 8989
+ :versions: 2.0.0b5
+
+ Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType
+ error in 1.4) when using an expression whose type included
+ :meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding"
+ (i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler
+ parameter.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8fbf3092a..cb30c7773 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -699,6 +699,8 @@ class SQLCompiler(Compiled):
"""
+ _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
+
positiontup = None
"""for a compiled construct that uses a positional paramstyle, will be
a sequence of strings, indicating the names of bound parameters in order.
@@ -1294,7 +1296,7 @@ class SQLCompiler(Compiled):
return expr
statement = re.sub(
- r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ self._post_compile_pattern,
process_expanding,
self.string,
)
@@ -2094,12 +2096,16 @@ class SQLCompiler(Compiled):
)
def _literal_execute_expanding_parameter_literal_binds(
- self, parameter, values
+ self, parameter, values, bind_expression_template=None
):
typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
if not values:
+ # empty IN expression. note we don't need to use
+ # bind_expression_template here because there are no
+ # expressions to render.
+
if typ_dialect_impl._is_tuple_type:
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
@@ -2120,6 +2126,12 @@ class SQLCompiler(Compiled):
)
):
+ if typ_dialect_impl._has_bind_expression:
+ raise NotImplementedError(
+ "bind_expression() on TupleType not supported with "
+ "literal_binds"
+ )
+
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
@@ -2135,10 +2147,29 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values)
)
else:
- replacement_expression = ", ".join(
- self.render_literal_value(value, parameter.type)
- for value in values
- )
+ if bind_expression_template:
+ post_compile_pattern = self._post_compile_pattern
+ m = post_compile_pattern.search(bind_expression_template)
+ assert m and m.group(
+ 2
+ ), "unexpected format for expanding parameter"
+
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ replacement_expression = ", ".join(
+ "%s%s%s"
+ % (
+ be_left,
+ self.render_literal_value(value, parameter.type),
+ be_right,
+ )
+ for value in values
+ )
+ else:
+ replacement_expression = ", ".join(
+ self.render_literal_value(value, parameter.type)
+ for value in values
+ )
return (), replacement_expression
@@ -2453,7 +2484,7 @@ class SQLCompiler(Compiled):
bind_expression,
skip_bind_expression=True,
within_columns_clause=within_columns_clause,
- literal_binds=literal_binds,
+ literal_binds=literal_binds and not bindparam.expanding,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
**kwargs
@@ -2461,14 +2492,26 @@ class SQLCompiler(Compiled):
if bindparam.expanding:
# for postcompile w/ expanding, move the "wrapped" part
# of this into the inside
+
m = re.match(
r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
)
+ assert m, "unexpected format for expanding parameter"
wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
m.group(2),
m.group(1),
m.group(3),
)
+
+ if literal_binds:
+ ret = self.render_literal_bindparam(
+ bindparam,
+ within_columns_clause=True,
+ bind_expression_template=wrapped,
+ **kwargs
+ )
+ return "(%s)" % ret
+
return wrapped
if not literal_binds:
@@ -2568,7 +2611,11 @@ class SQLCompiler(Compiled):
return ret
def render_literal_bindparam(
- self, bindparam, render_literal_value=NO_ARG, **kw
+ self,
+ bindparam,
+ render_literal_value=NO_ARG,
+ bind_expression_template=None,
+ **kw
):
if render_literal_value is not NO_ARG:
value = render_literal_value
@@ -2587,7 +2634,11 @@ class SQLCompiler(Compiled):
if bindparam.expanding:
leep = self._literal_execute_expanding_parameter_literal_binds
- to_update, replacement_expr = leep(bindparam, value)
+ to_update, replacement_expr = leep(
+ bindparam,
+ value,
+ bind_expression_template=bind_expression_template,
+ )
return replacement_expr
else:
return self.render_literal_value(value, bindparam.type)
diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py
index e0e0858a4..7c2192620 100644
--- a/test/sql/test_type_expressions.py
+++ b/test/sql/test_type_expressions.py
@@ -182,28 +182,40 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
"test_table WHERE test_table.y = lower(:y_1)",
)
- def test_in_binds(self):
+ @testing.variation(
+ "compile_opt", ["plain", "postcompile", "literal_binds"]
+ )
+ def test_in_binds(self, compile_opt):
table = self._fixture()
- self.assert_compile(
- select(table).where(
- table.c.y.in_(["hi", "there", "some", "expr"])
- ),
- "SELECT test_table.x, lower(test_table.y) AS y FROM "
- "test_table WHERE test_table.y IN "
- "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
- render_postcompile=False,
+ stmt = select(table).where(
+ table.c.y.in_(["hi", "there", "some", "expr"])
)
- self.assert_compile(
- select(table).where(
- table.c.y.in_(["hi", "there", "some", "expr"])
- ),
- "SELECT test_table.x, lower(test_table.y) AS y FROM "
- "test_table WHERE test_table.y IN "
- "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
- render_postcompile=True,
- )
+ if compile_opt.plain:
+ self.assert_compile(
+ stmt,
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
+ render_postcompile=False,
+ )
+ elif compile_opt.postcompile:
+ self.assert_compile(
+ stmt,
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
+ render_postcompile=True,
+ )
+ elif compile_opt.literal_binds:
+ self.assert_compile(
+ stmt,
+ "SELECT test_table.x, lower(test_table.y) AS y FROM "
+ "test_table WHERE test_table.y IN "
+ "(lower('hi'), lower('there'), lower('some'), lower('expr'))",
+ literal_binds=True,
+ )
def test_dialect(self):
table = self._fixture()