summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/engine/default.py113
-rw-r--r--lib/sqlalchemy/sql/compiler.py34
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py20
-rw-r--r--lib/sqlalchemy/sql/elements.py20
-rw-r--r--lib/sqlalchemy/testing/requirements.py8
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py56
6 files changed, 235 insertions, 16 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 628e23c9e..d1b54ab01 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -552,6 +552,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
# result column names
_translate_colname = None
+ _expanded_parameters = util.immutabledict()
+
@classmethod
def _init_ddl(cls, dialect, connection, dbapi_connection, compiled_ddl):
"""Initialize execution context for a DDLElement construct."""
@@ -645,6 +647,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
processors = compiled._bind_processors
+ if compiled.contains_expanding_parameters:
+ positiontup = self._expand_in_parameters(compiled, processors)
+ elif compiled.positional:
+ positiontup = self.compiled.positiontup
+
# Convert the dictionary of bind parameter values
# into a dict or list to be sent to the DBAPI's
# execute() or executemany() method.
@@ -652,7 +659,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if compiled.positional:
for compiled_params in self.compiled_parameters:
param = []
- for key in self.compiled.positiontup:
+ for key in positiontup:
if key in processors:
param.append(processors[key](compiled_params[key]))
else:
@@ -684,10 +691,97 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
)
parameters.append(param)
+
self.parameters = dialect.execute_sequence_format(parameters)
return self
+ def _expand_in_parameters(self, compiled, processors):
+ """handle special 'expanding' parameters, IN tuples that are rendered
+ on a per-parameter basis for an otherwise fixed SQL statement string.
+
+ """
+ if self.executemany:
+ raise exc.InvalidRequestError(
+ "'expanding' parameters can't be used with "
+ "executemany()")
+
+ if self.compiled.positional and self.compiled._numeric_binds:
+ # I'm not familiar with any DBAPI that uses 'numeric'
+ raise NotImplementedError(
+ "'expanding' bind parameters not supported with "
+ "'numeric' paramstyle at this time.")
+
+ self._expanded_parameters = {}
+
+ compiled_params = self.compiled_parameters[0]
+ if compiled.positional:
+ positiontup = []
+ else:
+ positiontup = None
+
+ replacement_expressions = {}
+ for name in (
+ self.compiled.positiontup if compiled.positional
+ else self.compiled.binds
+ ):
+ parameter = self.compiled.binds[name]
+ if parameter.expanding:
+ values = compiled_params.pop(name)
+ if not values:
+ raise exc.InvalidRequestError(
+ "'expanding' parameters can't be used with an "
+ "empty list"
+ )
+ elif isinstance(values[0], (tuple, list)):
+ to_update = [
+ ("%s_%s_%s" % (name, i, j), value)
+ for i, tuple_element in enumerate(values, 1)
+ for j, value in enumerate(tuple_element, 1)
+ ]
+ replacement_expressions[name] = ", ".join(
+ "(%s)" % ", ".join(
+ self.compiled.bindtemplate % {
+ "name":
+ to_update[i * len(tuple_element) + j][0]
+ }
+ for j, value in enumerate(tuple_element)
+ )
+ for i, tuple_element in enumerate(values)
+
+ )
+ else:
+ to_update = [
+ ("%s_%s" % (name, i), value)
+ for i, value in enumerate(values, 1)
+ ]
+ replacement_expressions[name] = ", ".join(
+ self.compiled.bindtemplate % {
+ "name": key}
+ for key, value in to_update
+ )
+ compiled_params.update(to_update)
+ processors.update(
+ (key, processors[name])
+ for key in to_update if name in processors
+ )
+ if compiled.positional:
+ positiontup.extend(name for name, value in to_update)
+ self._expanded_parameters[name] = [
+ expand_key for expand_key, value in to_update]
+ elif compiled.positional:
+ positiontup.append(name)
+
+ def process_expanding(m):
+ return replacement_expressions.pop(m.group(1))
+
+ self.statement = re.sub(
+ r"\[EXPANDING_(.+)\]",
+ process_expanding,
+ self.statement
+ )
+ return positiontup
+
@classmethod
def _init_statement(cls, dialect, connection, dbapi_connection,
statement, parameters):
@@ -1039,7 +1133,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
get_dbapi_type(self.dialect.dbapi)
if dbtype is not None and \
(not exclude_types or dbtype not in exclude_types):
- inputsizes.append(dbtype)
+ if key in self._expanded_parameters:
+ inputsizes.extend(
+ [dbtype] * len(self._expanded_parameters[key]))
+ else:
+ inputsizes.append(dbtype)
try:
self.cursor.setinputsizes(*inputsizes)
except BaseException as e:
@@ -1054,10 +1152,19 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
if dbtype is not None and \
(not exclude_types or dbtype not in exclude_types):
if translate:
+ # TODO: this part won't work w/ the
+ # expanded_parameters feature, e.g. for cx_oracle
+ # quoted bound names
key = translate.get(key, key)
if not self.dialect.supports_unicode_binds:
key = self.dialect._encoder(key)[0]
- inputsizes[key] = dbtype
+ if key in self._expanded_parameters:
+ inputsizes.update(
+ (expand_key, dbtype) for expand_key
+ in self._expanded_parameters[key]
+ )
+ else:
+ inputsizes[key] = dbtype
try:
self.cursor.setinputsizes(**inputsizes)
except BaseException as e:
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index cc4248009..6da064797 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -350,6 +350,14 @@ class SQLCompiler(Compiled):
columns with the table name (i.e. MySQL only)
"""
+ contains_expanding_parameters = False
+ """True if we've encountered bindparam(..., expanding=True).
+
+ These need to be converted before execution time against the
+ string statement.
+
+ """
+
ansi_bind_rules = False
"""SQL 92 doesn't allow bind parameters to be used
in the columns clause of a SELECT, nor does it allow
@@ -370,8 +378,14 @@ class SQLCompiler(Compiled):
True unless using an unordered TextAsFrom.
"""
- insert_prefetch = update_prefetch = ()
+ _numeric_binds = False
+ """
+ True if paramstyle is "numeric". This paramstyle is trickier than
+ all the others.
+ """
+
+ insert_prefetch = update_prefetch = ()
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
@@ -418,6 +432,7 @@ class SQLCompiler(Compiled):
self.positional = dialect.positional
if self.positional:
self.positiontup = []
+ self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes = None
@@ -439,7 +454,7 @@ class SQLCompiler(Compiled):
) and statement._returning:
self.returning = statement._returning
- if self.positional and dialect.paramstyle == 'numeric':
+ if self.positional and self._numeric_binds:
self._apply_numbered_params()
@property
@@ -492,7 +507,8 @@ class SQLCompiler(Compiled):
return dict(
(key, value) for key, value in
((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect))
+ bindparam.type._cached_bind_processor(self.dialect)
+ )
for bindparam in self.bind_names)
if value is not None
)
@@ -1238,7 +1254,8 @@ class SQLCompiler(Compiled):
self.binds[bindparam.key] = self.binds[name] = bindparam
- return self.bindparam_string(name, **kwargs)
+ return self.bindparam_string(
+ name, expanding=bindparam.expanding, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.effective_value
@@ -1300,13 +1317,18 @@ class SQLCompiler(Compiled):
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
- def bindparam_string(self, name, positional_names=None, **kw):
+ def bindparam_string(
+ self, name, positional_names=None, expanding=False, **kw):
if self.positional:
if positional_names is not None:
positional_names.append(name)
else:
self.positiontup.append(name)
- return self.bindtemplate % {'name': name}
+ if expanding:
+ self.contains_expanding_parameters = True
+ return "([EXPANDING_%s])" % name
+ else:
+ return self.bindtemplate % {'name': name}
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None,
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index d409ebacc..4ba53ef75 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -127,10 +127,18 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
return _boolean_compare(expr, op, seq_or_selectable,
negate=negate_op, **kw)
elif isinstance(seq_or_selectable, ClauseElement):
- raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions '
- 'or a selectable: %r' % seq_or_selectable)
+ if isinstance(seq_or_selectable, BindParameter) and \
+ seq_or_selectable.expanding:
+ return _boolean_compare(
+ expr, op,
+ seq_or_selectable,
+ negate=negate_op)
+ else:
+ raise exc.InvalidRequestError(
+ 'in_() accepts'
+ ' either a list of expressions, '
+ 'a selectable, or an "expanding" bound parameter: %r'
+ % seq_or_selectable)
# Handle non selectable arguments as sequences
args = []
@@ -139,8 +147,8 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
'in_() accepts'
- ' either a list of expressions '
- 'or a selectable: %r' % o)
+ ' either a list of expressions, '
+ 'a selectable, or an "expanding" bound parameter: %r' % o)
elif o is None:
o = Null()
else:
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 001c3d042..414e3f477 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -867,6 +867,7 @@ class BindParameter(ColumnElement):
def __init__(self, key, value=NO_ARG, type_=None,
unique=False, required=NO_ARG,
quote=None, callable_=None,
+ expanding=False,
isoutparam=False,
_compared_to_operator=None,
_compared_to_type=None):
@@ -1052,6 +1053,23 @@ class BindParameter(ColumnElement):
"OUT" parameter. This applies to backends such as Oracle which
support OUT parameters.
+ :param expanding:
+ if True, this parameter will be treated as an "expanding" parameter
+ at execution time; the parameter value is expected to be a sequence,
+ rather than a scalar value, and the string SQL statement will
+ be transformed on a per-execution basis to accomodate the sequence
+ with a variable number of parameter slots passed to the DBAPI.
+ This is to allow statement caching to be used in conjunction with
+ an IN clause.
+
+ .. note:: The "expanding" feature does not support "executemany"-
+ style parameter sets, nor does it support empty IN expressions.
+
+ .. note:: The "expanding" feature should be considered as
+ **experimental** within the 1.2 series.
+
+ .. versionadded:: 1.2
+
.. seealso::
:ref:`coretutorial_bind_param`
@@ -1093,6 +1111,8 @@ class BindParameter(ColumnElement):
self.callable = callable_
self.isoutparam = isoutparam
self.required = required
+ self.expanding = expanding
+
if type_ is None:
if _compared_to_type is not None:
self.type = \
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index d38a69159..95aef0e17 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -220,6 +220,14 @@ class SuiteRequirements(Requirements):
)
@property
+ def tuple_in(self):
+ """Target platform supports the syntax
+ "(x, y) IN ((x1, y1), (x2, y2), ...)"
+ """
+
+ return exclusions.closed()
+
+ @property
def duplicate_names_in_cursor_description(self):
"""target platform supports a SELECT statement that has
the same name repeated more than once in the columns list."""
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index e7de356b8..4086a4c24 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -2,7 +2,7 @@ from .. import fixtures, config
from ..assertions import eq_
from sqlalchemy import util
-from sqlalchemy import Integer, String, select, func, bindparam, union
+from sqlalchemy import Integer, String, select, func, bindparam, union, tuple_
from sqlalchemy import testing
from ..schema import Table, Column
@@ -310,3 +310,57 @@ class CompoundSelectTest(fixtures.TablesTest):
u1.order_by(u1.c.id),
[(2, 2, 3), (3, 3, 4)]
)
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("some_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer),
+ Column('y', Integer))
+
+ @classmethod
+ def insert_data(cls):
+ config.db.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ]
+ )
+
+ def _assert_result(self, select, result, params=()):
+ eq_(
+ config.db.execute(select, params).fetchall(),
+ result
+ )
+
+ def test_bound_in_scalar(self):
+ table = self.tables.some_table
+
+ stmt = select([table.c.id]).where(
+ table.c.x.in_(bindparam('q', expanding=True)))
+
+ self._assert_result(
+ stmt,
+ [(2, ), (3, ), (4, )],
+ params={"q": [2, 3, 4]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple(self):
+ table = self.tables.some_table
+
+ stmt = select([table.c.id]).where(
+ tuple_(table.c.x, table.c.y).in_(bindparam('q', expanding=True)))
+
+ self._assert_result(
+ stmt,
+ [(2, ), (3, ), (4, )],
+ params={"q": [(2, 3), (3, 4), (4, 5)]},
+ )