diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 113 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 34 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 56 |
6 files changed, 235 insertions, 16 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 628e23c9e..d1b54ab01 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -552,6 +552,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): # result column names _translate_colname = None + _expanded_parameters = util.immutabledict() + @classmethod def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl): """Initialize execution context for a DDLElement construct.""" @@ -645,6 +647,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): processors = compiled._bind_processors + if compiled.contains_expanding_parameters: + positiontup = self._expand_in_parameters(compiled, processors) + elif compiled.positional: + positiontup = self.compiled.positiontup + # Convert the dictionary of bind parameter values # into a dict or list to be sent to the DBAPI's # execute() or executemany() method. @@ -652,7 +659,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if compiled.positional: for compiled_params in self.compiled_parameters: param = [] - for key in self.compiled.positiontup: + for key in positiontup: if key in processors: param.append(processors[key](compiled_params[key])) else: @@ -684,10 +691,97 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ) parameters.append(param) + self.parameters = dialect.execute_sequence_format(parameters) return self + def _expand_in_parameters(self, compiled, processors): + """handle special 'expanding' parameters, IN tuples that are rendered + on a per-parameter basis for an otherwise fixed SQL statement string. + + """ + if self.executemany: + raise exc.InvalidRequestError( + "'expanding' parameters can't be used with " + "executemany()") + + if self.compiled.positional and self.compiled._numeric_binds: + # I'm not familiar with any DBAPI that uses 'numeric' + raise NotImplementedError( + "'expanding' bind parameters not supported with " + "'numeric' paramstyle at this time.") + + self._expanded_parameters = {} + + compiled_params = self.compiled_parameters[0] + if compiled.positional: + positiontup = [] + else: + positiontup = None + + replacement_expressions = {} + for name in ( + self.compiled.positiontup if compiled.positional + else self.compiled.binds + ): + parameter = self.compiled.binds[name] + if parameter.expanding: + values = compiled_params.pop(name) + if not values: + raise exc.InvalidRequestError( + "'expanding' parameters can't be used with an " + "empty list" + ) + elif isinstance(values[0], (tuple, list)): + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + replacement_expressions[name] = ", ".join( + "(%s)" % ", ".join( + self.compiled.bindtemplate % { + "name": + to_update[i * len(tuple_element) + j][0] + } + for j, value in enumerate(tuple_element) + ) + for i, tuple_element in enumerate(values) + + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expressions[name] = ", ".join( + self.compiled.bindtemplate % { + "name": key} + for key, value in to_update + ) + compiled_params.update(to_update) + processors.update( + (key, processors[name]) + for key in to_update if name in processors + ) + if compiled.positional: + positiontup.extend(name for name, value in to_update) + self._expanded_parameters[name] = [ + expand_key for expand_key, value in to_update] + elif compiled.positional: + positiontup.append(name) + + def process_expanding(m): + return replacement_expressions.pop(m.group(1)) + + self.statement = re.sub( + r"\[EXPANDING_(.+)\]", + process_expanding, + self.statement + ) + return positiontup + @classmethod def _init_statement(cls, dialect, connection, dbapi_connection, statement, parameters): @@ -1039,7 +1133,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): get_dbapi_type(self.dialect.dbapi) if dbtype is not None and \ (not exclude_types or dbtype not in exclude_types): - inputsizes.append(dbtype) + if key in self._expanded_parameters: + inputsizes.extend( + [dbtype] * len(self._expanded_parameters[key])) + else: + inputsizes.append(dbtype) try: self.cursor.setinputsizes(*inputsizes) except BaseException as e: @@ -1054,10 +1152,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if dbtype is not None and \ (not exclude_types or dbtype not in exclude_types): if translate: + # TODO: this part won't work w/ the + # expanded_parameters feature, e.g. for cx_oracle + # quoted bound names key = translate.get(key, key) if not self.dialect.supports_unicode_binds: key = self.dialect._encoder(key)[0] - inputsizes[key] = dbtype + if key in self._expanded_parameters: + inputsizes.update( + (expand_key, dbtype) for expand_key + in self._expanded_parameters[key] + ) + else: + inputsizes[key] = dbtype try: self.cursor.setinputsizes(**inputsizes) except BaseException as e: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cc4248009..6da064797 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -350,6 +350,14 @@ class SQLCompiler(Compiled): columns with the table name (i.e. MySQL only) """ + contains_expanding_parameters = False + """True if we've encountered bindparam(..., expanding=True). + + These need to be converted before execution time against the + string statement. + + """ + ansi_bind_rules = False """SQL 92 doesn't allow bind parameters to be used in the columns clause of a SELECT, nor does it allow @@ -370,8 +378,14 @@ class SQLCompiler(Compiled): True unless using an unordered TextAsFrom. """ - insert_prefetch = update_prefetch = () + _numeric_binds = False + """ + True if paramstyle is "numeric". This paramstyle is trickier than + all the others. + """ + + insert_prefetch = update_prefetch = () def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): @@ -418,6 +432,7 @@ class SQLCompiler(Compiled): self.positional = dialect.positional if self.positional: self.positiontup = [] + self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.ctes = None @@ -439,7 +454,7 @@ class SQLCompiler(Compiled): ) and statement._returning: self.returning = statement._returning - if self.positional and dialect.paramstyle == 'numeric': + if self.positional and self._numeric_binds: self._apply_numbered_params() @property @@ -492,7 +507,8 @@ class SQLCompiler(Compiled): return dict( (key, value) for key, value in ((self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect)) + bindparam.type._cached_bind_processor(self.dialect) + ) for bindparam in self.bind_names) if value is not None ) @@ -1238,7 +1254,8 @@ class SQLCompiler(Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string(name, **kwargs) + return self.bindparam_string( + name, expanding=bindparam.expanding, **kwargs) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.effective_value @@ -1300,13 +1317,18 @@ class SQLCompiler(Compiled): self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name, positional_names=None, **kw): + def bindparam_string( + self, name, positional_names=None, expanding=False, **kw): if self.positional: if positional_names is not None: positional_names.append(name) else: self.positiontup.append(name) - return self.bindtemplate % {'name': name} + if expanding: + self.contains_expanding_parameters = True + return "([EXPANDING_%s])" % name + else: + return self.bindtemplate % {'name': name} def visit_cte(self, cte, asfrom=False, ashint=False, fromhints=None, diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index d409ebacc..4ba53ef75 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -127,10 +127,18 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op, **kw) elif isinstance(seq_or_selectable, ClauseElement): - raise exc.InvalidRequestError( - 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % seq_or_selectable) + if isinstance(seq_or_selectable, BindParameter) and \ + seq_or_selectable.expanding: + return _boolean_compare( + expr, op, + seq_or_selectable, + negate=negate_op) + else: + raise exc.InvalidRequestError( + 'in_() accepts' + ' either a list of expressions, ' + 'a selectable, or an "expanding" bound parameter: %r' + % seq_or_selectable) # Handle non selectable arguments as sequences args = [] @@ -139,8 +147,8 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): if not isinstance(o, operators.ColumnOperators): raise exc.InvalidRequestError( 'in_() accepts' - ' either a list of expressions ' - 'or a selectable: %r' % o) + ' either a list of expressions, ' + 'a selectable, or an "expanding" bound parameter: %r' % o) elif o is None: o = Null() else: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 001c3d042..414e3f477 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -867,6 +867,7 @@ class BindParameter(ColumnElement): def __init__(self, key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, quote=None, callable_=None, + expanding=False, isoutparam=False, _compared_to_operator=None, _compared_to_type=None): @@ -1052,6 +1053,23 @@ class BindParameter(ColumnElement): "OUT" parameter. This applies to backends such as Oracle which support OUT parameters. + :param expanding: + if True, this parameter will be treated as an "expanding" parameter + at execution time; the parameter value is expected to be a sequence, + rather than a scalar value, and the string SQL statement will + be transformed on a per-execution basis to accomodate the sequence + with a variable number of parameter slots passed to the DBAPI. + This is to allow statement caching to be used in conjunction with + an IN clause. + + .. note:: The "expanding" feature does not support "executemany"- + style parameter sets, nor does it support empty IN expressions. + + .. note:: The "expanding" feature should be considered as + **experimental** within the 1.2 series. + + .. versionadded:: 1.2 + .. seealso:: :ref:`coretutorial_bind_param` @@ -1093,6 +1111,8 @@ class BindParameter(ColumnElement): self.callable = callable_ self.isoutparam = isoutparam self.required = required + self.expanding = expanding + if type_ is None: if _compared_to_type is not None: self.type = \ diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index d38a69159..95aef0e17 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -220,6 +220,14 @@ class SuiteRequirements(Requirements): ) @property + def tuple_in(self): + """Target platform supports the syntax + "(x, y) IN ((x1, y1), (x2, y2), ...)" + """ + + return exclusions.closed() + + @property def duplicate_names_in_cursor_description(self): """target platform supports a SELECT statement that has the same name repeated more than once in the columns list.""" diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index e7de356b8..4086a4c24 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -2,7 +2,7 @@ from .. import fixtures, config from ..assertions import eq_ from sqlalchemy import util -from sqlalchemy import Integer, String, select, func, bindparam, union +from sqlalchemy import Integer, String, select, func, bindparam, union, tuple_ from sqlalchemy import testing from ..schema import Table, Column @@ -310,3 +310,57 @@ class CompoundSelectTest(fixtures.TablesTest): u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] ) + + +class ExpandingBoundInTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer)) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ] + ) + + def _assert_result(self, select, result, params=()): + eq_( + config.db.execute(select, params).fetchall(), + result + ) + + def test_bound_in_scalar(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + table.c.x.in_(bindparam('q', expanding=True))) + + self._assert_result( + stmt, + [(2, ), (3, ), (4, )], + params={"q": [2, 3, 4]}, + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.y).in_(bindparam('q', expanding=True))) + + self._assert_result( + stmt, + [(2, ), (3, ), (4, )], + params={"q": [(2, 3), (3, 4), (4, 5)]}, + ) |
