summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-12-16 15:08:18 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-12-16 15:08:18 +0000
commit5bb48511a126b66ed06abf76d706ab707afafbf1 (patch)
tree17b1a4bf31be3a2a3aa18478cdd710f40270cf18 /lib/sqlalchemy/sql/compiler.py
parenta41d0dc4bdcc698643b6a4d76f265f5aa4765bee (diff)
parente06ef2154210ce1a5ced6f58330a258f7adfaa55 (diff)
downloadsqlalchemy-5bb48511a126b66ed06abf76d706ab707afafbf1.tar.gz
Merge "implement literal_binds with expanding + bind_expression" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py64
1 files changed, 56 insertions, 8 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 7aa89869e..66a294d10 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2871,12 +2871,16 @@ class SQLCompiler(Compiled):
)
def _literal_execute_expanding_parameter_literal_binds(
- self, parameter, values
+ self, parameter, values, bind_expression_template=None
):
typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
if not values:
+ # empty IN expression. note we don't need to use
+ # bind_expression_template here because there are no
+ # expressions to render.
+
if typ_dialect_impl._is_tuple_type:
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
@@ -2895,6 +2899,12 @@ class SQLCompiler(Compiled):
and not isinstance(values[0], (str, bytes))
):
+ if typ_dialect_impl._has_bind_expression:
+ raise NotImplementedError(
+ "bind_expression() on TupleType not supported with "
+ "literal_binds"
+ )
+
replacement_expression = (
"VALUES " if self.dialect.tuple_in_values else ""
) + ", ".join(
@@ -2910,10 +2920,29 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values)
)
else:
- replacement_expression = ", ".join(
- self.render_literal_value(value, parameter.type)
- for value in values
- )
+ if bind_expression_template:
+ post_compile_pattern = self._post_compile_pattern
+ m = post_compile_pattern.search(bind_expression_template)
+ assert m and m.group(
+ 2
+ ), "unexpected format for expanding parameter"
+
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ replacement_expression = ", ".join(
+ "%s%s%s"
+ % (
+ be_left,
+ self.render_literal_value(value, parameter.type),
+ be_right,
+ )
+ for value in values
+ )
+ else:
+ replacement_expression = ", ".join(
+ self.render_literal_value(value, parameter.type)
+ for value in values
+ )
return (), replacement_expression
@@ -3293,7 +3322,7 @@ class SQLCompiler(Compiled):
bind_expression,
skip_bind_expression=True,
within_columns_clause=within_columns_clause,
- literal_binds=literal_binds,
+ literal_binds=literal_binds and not bindparam.expanding,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
accumulate_bind_names=accumulate_bind_names,
@@ -3302,6 +3331,7 @@ class SQLCompiler(Compiled):
if bindparam.expanding:
# for postcompile w/ expanding, move the "wrapped" part
# of this into the inside
+
m = re.match(
r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
)
@@ -3311,6 +3341,16 @@ class SQLCompiler(Compiled):
m.group(1),
m.group(3),
)
+
+ if literal_binds:
+ ret = self.render_literal_bindparam(
+ bindparam,
+ within_columns_clause=True,
+ bind_expression_template=wrapped,
+ **kwargs,
+ )
+ return "(%s)" % ret
+
return wrapped
if not literal_binds:
@@ -3436,7 +3476,11 @@ class SQLCompiler(Compiled):
raise NotImplementedError()
def render_literal_bindparam(
- self, bindparam, render_literal_value=NO_ARG, **kw
+ self,
+ bindparam,
+ render_literal_value=NO_ARG,
+ bind_expression_template=None,
+ **kw,
):
if render_literal_value is not NO_ARG:
value = render_literal_value
@@ -3455,7 +3499,11 @@ class SQLCompiler(Compiled):
if bindparam.expanding:
leep = self._literal_execute_expanding_parameter_literal_binds
- to_update, replacement_expr = leep(bindparam, value)
+ to_update, replacement_expr = leep(
+ bindparam,
+ value,
+ bind_expression_template=bind_expression_template,
+ )
return replacement_expr
else:
return self.render_literal_value(value, bindparam.type)