diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2023-01-28 19:50:25 -0500 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2023-01-30 22:28:53 +0100 |
commit | d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6 (patch) | |
tree | 6b89a07b8bda5a469bf6c8dde165101315f571ed /lib/sqlalchemy/sql | |
parent | b99b0c522ddb94468da27867ddfa1f7e2633c920 (diff) | |
download | sqlalchemy-d23dcbaea2a8e000c5fa2ba443e1b683b3b79fa6.tar.gz |
don't count / gather INSERT bind names inside of a CTE
Fixed regression related to the implementation for the new
"insertmanyvalues" feature where an internal ``TypeError`` would occur in
arrangements where a :func:`_sql.insert` would be referred to inside
of another :func:`_sql.insert` via a CTE; made additional repairs for this
use case for positional dialects such as asyncpg when using
"insertmanyvalues".
At the core here is a change to positional insertmanyvalues
where we now get exactly the positions for the "manyvalues" within
the larger list, allowing non-"manyvalues" on the left and right
sides at the same time, not assuming anything about how RETURNING
renders etc., since CTEs are in the mix also.
Fixes: #9173
Change-Id: I5ff071fbef0d92a2d6046b9c4e609bb008438afd
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 137 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 11 |
2 files changed, 105 insertions, 43 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2c50081fb..d4ddc2e5d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1545,12 +1545,12 @@ class SQLCompiler(Compiled): self.positiontup = list(param_pos) if self.escaped_bind_names: - reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} - assert len(self.escaped_bind_names) == len(reverse_escape) + len_before = len(param_pos) param_pos = { self.escaped_bind_names.get(name, name): pos for name, pos in param_pos.items() } + assert len(param_pos) == len_before # Can't use format here since % chars are not escaped. self.string = self._pyformat_pattern.sub( @@ -3374,7 +3374,6 @@ class SQLCompiler(Compiled): skip_bind_expression=False, literal_execute=False, render_postcompile=False, - accumulate_bind_names=None, **kwargs, ): if not skip_bind_expression: @@ -3388,7 +3387,6 @@ class SQLCompiler(Compiled): literal_binds=literal_binds and not bindparam.expanding, literal_execute=literal_execute, render_postcompile=render_postcompile, - accumulate_bind_names=accumulate_bind_names, **kwargs, ) if bindparam.expanding: @@ -3490,9 +3488,6 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - if accumulate_bind_names is not None: - accumulate_bind_names.add(name) - # if we are given a cache key that we're going to match against, # relate the bindparam here to one that is most likely present # in the "extracted params" portion of the cache key. 
this is used @@ -3646,11 +3641,19 @@ class SQLCompiler(Compiled): expanding: bool = False, escaped_from: Optional[str] = None, bindparam_type: Optional[TypeEngine[Any]] = None, + accumulate_bind_names: Optional[Set[str]] = None, + visited_bindparam: Optional[List[str]] = None, **kw: Any, ) -> str: - if self._visited_bindparam is not None: - self._visited_bindparam.append(name) + # TODO: accumulate_bind_names is passed by crud.py to gather + # names on a per-value basis, visited_bindparam is passed by + # visit_insert() to collect all parameters in the statement. + # see if this gathering can be simplified somehow + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + if visited_bindparam is not None: + visited_bindparam.append(name) if not escaped_from: @@ -5086,6 +5089,8 @@ class SQLCompiler(Compiled): assert insert_crud_params is not None escaped_bind_names: Mapping[str, str] + expand_pos_lower_index = expand_pos_upper_index = 0 + if not self.positional: if self.escaped_bind_names: escaped_bind_names = self.escaped_bind_names @@ -5124,6 +5129,31 @@ class SQLCompiler(Compiled): keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " + + all_names_we_will_expand: Set[str] = set() + for elem in imv.insert_crud_params: + all_names_we_will_expand.update(elem[3]) + + # get the start and end position in a particular list + # of parameters where we will be doing the "expanding". 
+ # statements can have params on either side or both sides, + # given RETURNING and CTEs + if all_names_we_will_expand: + positiontup = self.positiontup + assert positiontup is not None + + all_expand_positions = { + idx + for idx, name in enumerate(positiontup) + if name in all_names_we_will_expand + } + expand_pos_lower_index = min(all_expand_positions) + expand_pos_upper_index = max(all_expand_positions) + 1 + assert ( + len(all_expand_positions) + == expand_pos_upper_index - expand_pos_lower_index + ) + if self._numeric_binds: escaped = re.escape(self._numeric_binds_identifier_char) executemany_values_w_comma = re.sub( @@ -5149,52 +5179,61 @@ class SQLCompiler(Compiled): replaced_parameters: Any if self.positional: - # the assumption here is that any parameters that are not - # in the VALUES clause are expected to be parameterized - # expressions in the RETURNING (or maybe ON CONFLICT) clause. - # So based on - # which sequence comes first in the compiler's INSERT - # statement tells us where to expand the parameters. - - # otherwise we probably shouldn't be doing insertmanyvalues - # on the statement. 
- num_ins_params = imv.num_positional_params_counted batch_iterator: Iterable[Tuple[Any, ...]] if num_ins_params == len(batch[0]): - extra_params = () + extra_params_left = extra_params_right = () batch_iterator = batch - elif self.returning_precedes_values or self._numeric_binds: - extra_params = batch[0][:-num_ins_params] - batch_iterator = (b[-num_ins_params:] for b in batch) else: - extra_params = batch[0][num_ins_params:] - batch_iterator = (b[:num_ins_params] for b in batch) + extra_params_left = batch[0][:expand_pos_lower_index] + extra_params_right = batch[0][expand_pos_upper_index:] + batch_iterator = ( + b[expand_pos_lower_index:expand_pos_upper_index] + for b in batch + ) + + expanded_values_string = ( + executemany_values_w_comma * len(batch) + )[:-2] - values_string = (executemany_values_w_comma * len(batch))[:-2] if self._numeric_binds and num_ins_params > 0: + # numeric will always number the parameters inside of + # VALUES (and thus order self.positiontup) to be higher + # than non-VALUES parameters, no matter where in the + # statement those non-VALUES parameters appear (this is + # ensured in _process_numeric by numbering first all + # params that are not in _values_bindparam) + # therefore all extra params are always + # on the left side and numbered lower than the VALUES + # parameters + assert not extra_params_right + + start = expand_pos_lower_index + 1 + end = num_ins_params * (len(batch)) + start + # need to format here, since statement may contain # unescaped %, while values_string contains just (%s, %s) - start = len(extra_params) + 1 - end = num_ins_params * len(batch) + start positions = tuple( f"{self._numeric_binds_identifier_char}{i}" for i in range(start, end) ) - values_string = values_string % positions + expanded_values_string = expanded_values_string % positions replaced_statement = statement.replace( - "__EXECMANY_TOKEN__", values_string + "__EXECMANY_TOKEN__", expanded_values_string ) replaced_parameters = tuple( 
itertools.chain.from_iterable(batch_iterator) ) - if self.returning_precedes_values or self._numeric_binds: - replaced_parameters = extra_params + replaced_parameters - else: - replaced_parameters = replaced_parameters + extra_params + + replaced_parameters = ( + extra_params_left + + replaced_parameters + + extra_params_right + ) + else: replaced_values_clauses = [] replaced_parameters = base_parameters.copy() @@ -5224,7 +5263,7 @@ class SQLCompiler(Compiled): ) batchnum += 1 - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, visited_bindparam=None, **kw): compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw @@ -5250,6 +5289,9 @@ class SQLCompiler(Compiled): counted_bindparam = 0 + # reset any incoming "visited_bindparam" collection + visited_bindparam = None + # for positional, insertmanyvalues needs to know how many # bound parameters are in the VALUES sequence; there's no simple # rule because default expressions etc. can have zero or more @@ -5257,21 +5299,30 @@ class SQLCompiler(Compiled): # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy if self.positional: - self._visited_bindparam = [] + # if we are inside a CTE, don't count parameters + # here since they wont be for insertmanyvalues. keep + # visited_bindparam at None so no counting happens. 
+ # see #9173 + has_visiting_cte = "visiting_cte" in kw + if not has_visiting_cte: + visited_bindparam = [] crud_params_struct = crud._get_crud_params( - self, insert_stmt, compile_state, toplevel, **kw + self, + insert_stmt, + compile_state, + toplevel, + visited_bindparam=visited_bindparam, + **kw, ) - if self.positional: - assert self._visited_bindparam is not None - counted_bindparam = len(self._visited_bindparam) + if self.positional and visited_bindparam is not None: + counted_bindparam = len(visited_bindparam) if self._numeric_binds: if self._values_bindparam is not None: - self._values_bindparam += self._visited_bindparam + self._values_bindparam += visited_bindparam else: - self._values_bindparam = self._visited_bindparam - self._visited_bindparam = None + self._values_bindparam = visited_bindparam crud_params_single = crud_params_struct.single_params diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 5017afa78..04b62d1ff 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -150,6 +150,17 @@ def _get_crud_params( compiler.update_prefetch = [] compiler.implicit_returning = [] + visiting_cte = kw.get("visiting_cte", None) + if visiting_cte is not None: + # for insert -> CTE -> insert, don't populate an incoming + # _crud_accumulate_bind_names collection; the INSERT we process here + # will not be inline within the VALUES of the enclosing INSERT as the + # CTE is placed on the outside. See issue #9173 + kw.pop("accumulate_bind_names", None) + assert ( + "accumulate_bind_names" not in kw + ), "Don't know how to handle insert within insert without a CTE" + # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account |