summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2023-01-28 19:50:25 -0500
committerFederico Caselli <cfederico87@gmail.com>2023-01-30 22:28:53 +0100
commitd23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6 (patch)
tree6b89a07b8bda5a469bf6c8dde165101315f571ed /lib/sqlalchemy/sql/compiler.py
parentb99b0c522ddb94468da27867ddfa1f7e2633c920 (diff)
downloadsqlalchemy-d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6.tar.gz
don't count / gather INSERT bind names inside of a CTE
Fixed regression related to the implementation for the new "insertmanyvalues" feature where an internal ``TypeError`` would occur in arrangements where a :func:`_sql.insert` would be referred towards inside of another :func:`_sql.insert` via a CTE; made additional repairs for this use case for positional dialects such as asyncpg when using "insertmanyvalues". at the core here is a change to positional insertmanyvalues where we now get exactly the positions for the "manyvalues" within the larger list, allowing non-"manyvalues" on the left and right sides at the same time, not assuming anything about how RETURNING renders etc., since CTEs are in the mix also. Fixes: #9173 Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py137
1 files changed, 94 insertions, 43 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2c50081fb..d4ddc2e5d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1545,12 +1545,12 @@ class SQLCompiler(Compiled):
self.positiontup = list(param_pos)
if self.escaped_bind_names:
- reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
- assert len(self.escaped_bind_names) == len(reverse_escape)
+ len_before = len(param_pos)
param_pos = {
self.escaped_bind_names.get(name, name): pos
for name, pos in param_pos.items()
}
+ assert len(param_pos) == len_before
# Can't use format here since % chars are not escaped.
self.string = self._pyformat_pattern.sub(
@@ -3374,7 +3374,6 @@ class SQLCompiler(Compiled):
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
- accumulate_bind_names=None,
**kwargs,
):
if not skip_bind_expression:
@@ -3388,7 +3387,6 @@ class SQLCompiler(Compiled):
literal_binds=literal_binds and not bindparam.expanding,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
- accumulate_bind_names=accumulate_bind_names,
**kwargs,
)
if bindparam.expanding:
@@ -3490,9 +3488,6 @@ class SQLCompiler(Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
- if accumulate_bind_names is not None:
- accumulate_bind_names.add(name)
-
# if we are given a cache key that we're going to match against,
# relate the bindparam here to one that is most likely present
# in the "extracted params" portion of the cache key. this is used
@@ -3646,11 +3641,19 @@ class SQLCompiler(Compiled):
expanding: bool = False,
escaped_from: Optional[str] = None,
bindparam_type: Optional[TypeEngine[Any]] = None,
+ accumulate_bind_names: Optional[Set[str]] = None,
+ visited_bindparam: Optional[List[str]] = None,
**kw: Any,
) -> str:
- if self._visited_bindparam is not None:
- self._visited_bindparam.append(name)
+ # TODO: accumulate_bind_names is passed by crud.py to gather
+ # names on a per-value basis, visited_bindparam is passed by
+ # visit_insert() to collect all parameters in the statement.
+ # see if this gathering can be simplified somehow
+ if accumulate_bind_names is not None:
+ accumulate_bind_names.add(name)
+ if visited_bindparam is not None:
+ visited_bindparam.append(name)
if not escaped_from:
@@ -5086,6 +5089,8 @@ class SQLCompiler(Compiled):
assert insert_crud_params is not None
escaped_bind_names: Mapping[str, str]
+ expand_pos_lower_index = expand_pos_upper_index = 0
+
if not self.positional:
if self.escaped_bind_names:
escaped_bind_names = self.escaped_bind_names
@@ -5124,6 +5129,31 @@ class SQLCompiler(Compiled):
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
+
+ all_names_we_will_expand: Set[str] = set()
+ for elem in imv.insert_crud_params:
+ all_names_we_will_expand.update(elem[3])
+
+ # get the start and end position in a particular list
+ # of parameters where we will be doing the "expanding".
+ # statements can have params on either side or both sides,
+ # given RETURNING and CTEs
+ if all_names_we_will_expand:
+ positiontup = self.positiontup
+ assert positiontup is not None
+
+ all_expand_positions = {
+ idx
+ for idx, name in enumerate(positiontup)
+ if name in all_names_we_will_expand
+ }
+ expand_pos_lower_index = min(all_expand_positions)
+ expand_pos_upper_index = max(all_expand_positions) + 1
+ assert (
+ len(all_expand_positions)
+ == expand_pos_upper_index - expand_pos_lower_index
+ )
+
if self._numeric_binds:
escaped = re.escape(self._numeric_binds_identifier_char)
executemany_values_w_comma = re.sub(
@@ -5149,52 +5179,61 @@ class SQLCompiler(Compiled):
replaced_parameters: Any
if self.positional:
- # the assumption here is that any parameters that are not
- # in the VALUES clause are expected to be parameterized
- # expressions in the RETURNING (or maybe ON CONFLICT) clause.
- # So based on
- # which sequence comes first in the compiler's INSERT
- # statement tells us where to expand the parameters.
-
- # otherwise we probably shouldn't be doing insertmanyvalues
- # on the statement.
-
num_ins_params = imv.num_positional_params_counted
batch_iterator: Iterable[Tuple[Any, ...]]
if num_ins_params == len(batch[0]):
- extra_params = ()
+ extra_params_left = extra_params_right = ()
batch_iterator = batch
- elif self.returning_precedes_values or self._numeric_binds:
- extra_params = batch[0][:-num_ins_params]
- batch_iterator = (b[-num_ins_params:] for b in batch)
else:
- extra_params = batch[0][num_ins_params:]
- batch_iterator = (b[:num_ins_params] for b in batch)
+ extra_params_left = batch[0][:expand_pos_lower_index]
+ extra_params_right = batch[0][expand_pos_upper_index:]
+ batch_iterator = (
+ b[expand_pos_lower_index:expand_pos_upper_index]
+ for b in batch
+ )
+
+ expanded_values_string = (
+ executemany_values_w_comma * len(batch)
+ )[:-2]
- values_string = (executemany_values_w_comma * len(batch))[:-2]
if self._numeric_binds and num_ins_params > 0:
+ # numeric will always number the parameters inside of
+ # VALUES (and thus order self.positiontup) to be higher
+ # than non-VALUES parameters, no matter where in the
+ # statement those non-VALUES parameters appear (this is
+ # ensured in _process_numeric by numbering first all
+ # params that are not in _values_bindparam)
+ # therefore all extra params are always
+ # on the left side and numbered lower than the VALUES
+ # parameters
+ assert not extra_params_right
+
+ start = expand_pos_lower_index + 1
+ end = num_ins_params * (len(batch)) + start
+
# need to format here, since statement may contain
# unescaped %, while values_string contains just (%s, %s)
- start = len(extra_params) + 1
- end = num_ins_params * len(batch) + start
positions = tuple(
f"{self._numeric_binds_identifier_char}{i}"
for i in range(start, end)
)
- values_string = values_string % positions
+ expanded_values_string = expanded_values_string % positions
replaced_statement = statement.replace(
- "__EXECMANY_TOKEN__", values_string
+ "__EXECMANY_TOKEN__", expanded_values_string
)
replaced_parameters = tuple(
itertools.chain.from_iterable(batch_iterator)
)
- if self.returning_precedes_values or self._numeric_binds:
- replaced_parameters = extra_params + replaced_parameters
- else:
- replaced_parameters = replaced_parameters + extra_params
+
+ replaced_parameters = (
+ extra_params_left
+ + replaced_parameters
+ + extra_params_right
+ )
+
else:
replaced_values_clauses = []
replaced_parameters = base_parameters.copy()
@@ -5224,7 +5263,7 @@ class SQLCompiler(Compiled):
)
batchnum += 1
- def visit_insert(self, insert_stmt, **kw):
+ def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
@@ -5250,6 +5289,9 @@ class SQLCompiler(Compiled):
counted_bindparam = 0
+ # reset any incoming "visited_bindparam" collection
+ visited_bindparam = None
+
# for positional, insertmanyvalues needs to know how many
# bound parameters are in the VALUES sequence; there's no simple
# rule because default expressions etc. can have zero or more
@@ -5257,21 +5299,30 @@ class SQLCompiler(Compiled):
# this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
if self.positional:
- self._visited_bindparam = []
+ # if we are inside a CTE, don't count parameters
+ # here since they wont be for insertmanyvalues. keep
+ # visited_bindparam at None so no counting happens.
+ # see #9173
+ has_visiting_cte = "visiting_cte" in kw
+ if not has_visiting_cte:
+ visited_bindparam = []
crud_params_struct = crud._get_crud_params(
- self, insert_stmt, compile_state, toplevel, **kw
+ self,
+ insert_stmt,
+ compile_state,
+ toplevel,
+ visited_bindparam=visited_bindparam,
+ **kw,
)
- if self.positional:
- assert self._visited_bindparam is not None
- counted_bindparam = len(self._visited_bindparam)
+ if self.positional and visited_bindparam is not None:
+ counted_bindparam = len(visited_bindparam)
if self._numeric_binds:
if self._values_bindparam is not None:
- self._values_bindparam += self._visited_bindparam
+ self._values_bindparam += visited_bindparam
else:
- self._values_bindparam = self._visited_bindparam
- self._visited_bindparam = None
+ self._values_bindparam = visited_bindparam
crud_params_single = crud_params_struct.single_params