diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/fixtures.py | 4 |
3 files changed, 43 insertions, 43 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index fc114efa3..c188e155c 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1001,14 +1001,31 @@ class DefaultExecutionContext(ExecutionContext): self.parameters = core_positional_parameters else: core_dict_parameters: MutableSequence[Dict[str, Any]] = [] + escaped_names = compiled.escaped_bind_names + + # note that currently, "expanded" parameters will be present + # in self.compiled_parameters in their quoted form. This is + # slightly inconsistent with the approach taken as of + # #8056 where self.compiled_parameters is meant to contain unquoted + # param names. + d_param: Dict[str, Any] for compiled_params in self.compiled_parameters: - - d_param: Dict[str, Any] = { - key: flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] - for key in compiled_params - } + if escaped_names: + d_param = { + escaped_names.get(key, key): flattened_processors[key]( + compiled_params[key] + ) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } + else: + d_param = { + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } core_dict_parameters.append(d_param) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 63ed45a96..12a598717 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1143,16 +1143,11 @@ class SQLCompiler(Compiled): str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] ]: - _escaped_bind_names = self.escaped_bind_names - has_escaped_names = bool(_escaped_bind_names) - # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve return dict( ( - _escaped_bind_names.get(key, key) - if has_escaped_names - else key, + key, value, ) # type: ignore for key, value in ( @@ -1186,8 +1181,6 @@ class SQLCompiler(Compiled): ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" - has_escaped_names = bool(self.escaped_bind_names) - if extracted_parameters: # related the bound parameters collected in the original cache key # to those collected in the incoming cache key. They will not have @@ -1217,16 +1210,10 @@ class SQLCompiler(Compiled): if params: pd = {} for bindparam, name in self.bind_names.items(): - escaped_name = ( - self.escaped_bind_names.get(name, name) - if has_escaped_names - else name - ) - if bindparam.key in params: - pd[escaped_name] = params[bindparam.key] + pd[name] = params[bindparam.key] elif name in params: - pd[escaped_name] = params[name] + pd[name] = params[name] elif _check and bindparam.required: if _group_number: @@ -1251,19 +1238,13 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[escaped_name] = value_param.effective_value + pd[name] = value_param.effective_value else: - pd[escaped_name] = value_param.value + pd[name] = value_param.value return pd else: pd = {} for bindparam, name in self.bind_names.items(): - escaped_name = ( - self.escaped_bind_names.get(name, name) - if has_escaped_names - else name - ) - if _check and bindparam.required: if _group_number: raise exc.InvalidRequestError( @@ -1285,9 +1266,9 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[escaped_name] = value_param.effective_value + pd[name] = value_param.effective_value else: - pd[escaped_name] = value_param.value + pd[name] = value_param.value return pd @util.memoized_instancemethod @@ -1359,6 +1340,7 @@ class SQLCompiler(Compiled): N as a bound parameter. """ + if parameters is None: parameters = self.construct_params() @@ -1435,7 +1417,12 @@ class SQLCompiler(Compiled): # process it. the single name is being replaced with # individual numbered parameters for each value in the # param. - values = parameters.pop(escaped_name) + # + # note we are also inserting *escaped* parameter names + # into the given dictionary. default dialect will + # use these param names directly as they will not be + # in the escaped_bind_names dictionary. + values = parameters.pop(name) leep = self._literal_execute_expanding_parameter to_update, replacement_expr = leep( @@ -1541,15 +1528,7 @@ class SQLCompiler(Compiled): @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._get_bind_name_for_col - if self.escaped_bind_names: - - def _get(obj): - key = getter(obj) - return self.escaped_bind_names.get(key, key) - - return _get - else: - return getter + return getter @util.memoized_property @util.preload_module("sqlalchemy.engine.result") diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 4b5366186..ae7a42488 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -178,6 +178,10 @@ class TestBase: return go + @config.fixture + def fixture_session(self): + return fixture_session() + @config.fixture() def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" |
