diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/mysqldb.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/associationproxy.py | 100 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/declarative/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 47 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/dynamic.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 133 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/default_comparator.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/engines.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 42 |
17 files changed, 367 insertions, 66 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index cd17bcdc4..07eca78bb 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1180,8 +1180,18 @@ class MySQLCompiler(compiler.SQLCompiler): fromhints=from_hints, **kw) for t in [from_table] + extra_froms) - def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM (SELECT 1) as _empty_set WHERE 1!=1' + def visit_empty_set_expr(self, element_types): + return ( + "SELECT %(outer)s FROM (SELECT %(inner)s) " + "as _empty_set WHERE 1!=1" % { + "inner": ", ".join( + "1 AS _in_%s" % idx + for idx, type_ in enumerate(element_types)), + "outer": ", ".join( + "_in_%s" % idx + for idx, type_ in enumerate(element_types)) + } + ) class MySQLDDLCompiler(compiler.DDLCompiler): diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 7554d244c..dfa9b52df 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -87,6 +87,19 @@ class MySQLDialect_mysqldb(MySQLDialect): def __init__(self, server_side_cursors=False, **kwargs): super(MySQLDialect_mysqldb, self).__init__(**kwargs) self.server_side_cursors = server_side_cursors + self._mysql_dbapi_version = self._parse_dbapi_version( + self.dbapi.__version__) if self.dbapi is not None \ + and hasattr(self.dbapi, '__version__') else (0, 0, 0) + + def _parse_dbapi_version(self, version): + m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + if m: + return tuple( + int(x) + for x in m.group(1, 2, 3) + if x is not None) + else: + return (0, 0, 0) @util.langhelpers.memoized_property def supports_server_side_cursors(self): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 11fcc41d5..5251a000d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1485,14 +1485,17 @@ class PGCompiler(compiler.SQLCompiler): if escape else '' ) - def visit_empty_set_expr(self, type_, **kw): + def visit_empty_set_expr(self, element_types): # cast the empty set to the type we are comparing against. if # we are comparing against the null type, pick an arbitrary # datatype for the empty set - if type_._isnull: - type_ = INTEGER() - return 'SELECT CAST(NULL AS %s) WHERE 1!=1' % \ - self.dialect.type_compiler.process(type_, **kw) + return 'SELECT %s WHERE 1!=1' % ( + ", ".join( + "CAST(NULL AS %s)" % self.dialect.type_compiler.process( + INTEGER() if type_._isnull else type_, + ) for type_ in element_types or [INTEGER()] + ), + ) def render_literal_value(self, value, type_): value = super(PGCompiler, self).render_literal_value(value, type_) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index f48217a4e..5c96e4240 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -737,7 +737,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): to_update = [] replacement_expressions[name] = ( self.compiled.visit_empty_set_expr( - type_=parameter.type) + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) ) elif isinstance(values[0], (tuple, list)): diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 1c28b10a1..629b4ac64 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -17,6 +17,7 @@ import operator from .. import exc, orm, util from ..orm import collections, interfaces from ..sql import or_ +from ..sql.operators import ColumnOperators from .. import inspect @@ -217,7 +218,7 @@ class AssociationProxy(interfaces.InspectionAttrInfo): except KeyError: owner = self._calc_owner(class_) if owner is not None: - result = AssociationProxyInstance(self, owner) + result = AssociationProxyInstance.for_proxy(self, owner) setattr(class_, self.key + "_inst", result) return result else: @@ -283,13 +284,49 @@ class AssociationProxyInstance(object): """ - def __init__(self, parent, owning_class): + def __init__(self, parent, owning_class, target_class, value_attr): self.parent = parent self.key = parent.key self.owning_class = owning_class self.target_collection = parent.target_collection self.value_attr = parent.value_attr self.collection_class = None + self.target_class = target_class + self.value_attr = value_attr + + target_class = None + """The intermediary class handled by this + :class:`.AssociationProxyInstance`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + + @classmethod + def for_proxy(cls, parent, owning_class): + target_collection = parent.target_collection + value_attr = parent.value_attr + prop = orm.class_mapper(owning_class).\ + get_property(target_collection) + target_class = prop.mapper.class_ + + target_assoc = cls._cls_unwrap_target_assoc_proxy( + target_class, value_attr) + if target_assoc is not None: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + is_object = getattr(target_class, value_attr).impl.uses_objects + if is_object: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + else: + return ColumnAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) def _get_property(self): return orm.class_mapper(self.owning_class).\ @@ -299,13 +336,18 @@ class AssociationProxyInstance(object): def _comparator(self): return self._get_property().comparator - @util.memoized_property - def _unwrap_target_assoc_proxy(self): - attr = getattr(self.target_class, self.value_attr) + @classmethod + def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr): + attr = getattr(target_class, value_attr) if isinstance(attr, (AssociationProxy, AssociationProxyInstance)): return attr return None + @util.memoized_property + def _unwrap_target_assoc_proxy(self): + return self._cls_unwrap_target_assoc_proxy( + self.target_class, self.value_attr) + @property def remote_attr(self): """The 'remote' :class:`.MapperProperty` referenced by this @@ -353,17 +395,6 @@ class AssociationProxyInstance(object): return (self.local_attr, self.remote_attr) @util.memoized_property - def target_class(self): - """The intermediary class handled by this - :class:`.AssociationProxyInstance`. - - Intercepted append/set/assignment events will result - in the generation of new instances of this class. - - """ - return self._get_property().mapper.class_ - - @util.memoized_property def scalar(self): """Return ``True`` if this :class:`.AssociationProxyInstance` proxies a scalar relationship on the local side.""" @@ -378,9 +409,9 @@ class AssociationProxyInstance(object): return not self._get_property().\ mapper.get_property(self.value_attr).uselist - @util.memoized_property + @property def _target_is_object(self): - return getattr(self.target_class, self.value_attr).impl.uses_objects + raise NotImplementedError() def _initialize_scalar_accessors(self): if self.parent.getset_factory: @@ -587,6 +618,12 @@ class AssociationProxyInstance(object): return self._criterion_exists( criterion=criterion, is_has=True, **kwargs) + +class ObjectAssociationProxyInstance(AssociationProxyInstance): + """an :class:`.AssociationProxyInstance` that has an object as a target. + """ + _target_is_object = True + def contains(self, obj): """Produce a proxied 'contains' expression using EXISTS. @@ -611,7 +648,7 @@ class AssociationProxyInstance(object): elif self._target_is_object and self.scalar and \ self._value_is_scalar: raise exc.InvalidRequestError( - "contains() doesn't apply to a scalar endpoint; use ==") + "contains() doesn't apply to a scalar object endpoint; use ==") else: return self._comparator._criterion_exists(**{self.value_attr: obj}) @@ -634,6 +671,31 @@ class AssociationProxyInstance(object): getattr(self.target_class, self.value_attr) != obj) +class ColumnAssociationProxyInstance( + ColumnOperators, AssociationProxyInstance): + """an :class:`.AssociationProxyInstance` that has a database column as a + target. + """ + _target_is_object = False + + def __eq__(self, other): + # special case "is None" to check for no related row as well + expr = self._criterion_exists( + self.remote_attr.operate(operator.eq, other) + ) + if other is None: + return or_( + expr, self._comparator == None + ) + else: + return expr + + def operate(self, op, *other, **kwargs): + return self._criterion_exists( + self.remote_attr.operate(op, *other, **kwargs) + ) + + class _lazy_collection(object): def __init__(self, obj, target): self.parent = obj diff --git a/lib/sqlalchemy/ext/declarative/base.py b/lib/sqlalchemy/ext/declarative/base.py index 9e15582d6..a6642364d 100644 --- a/lib/sqlalchemy/ext/declarative/base.py +++ b/lib/sqlalchemy/ext/declarative/base.py @@ -295,7 +295,8 @@ class _MapperConfig(object): # produces nested proxies, so we are only # looking one level deep right now. if isinstance(ret, InspectionAttr) and \ - ret._is_internal_proxy: + ret._is_internal_proxy and not isinstance( + ret.original_property, MapperProperty): ret = ret.descriptor dict_[name] = column_copies[obj] = ret diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 6ef4c5612..f86e4fc93 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -45,7 +45,7 @@ class ShardedQuery(Query): def iter_for_shard(shard_id): context.attributes['shard_id'] = context.identity_token = shard_id result = self._connection_from_session( - mapper=self._mapper_zero(), + mapper=self._bind_mapper(), shard_id=shard_id).execute( context.statement, self._params) @@ -64,6 +64,28 @@ class ShardedQuery(Query): # were done, this is where it would happen return iter(partial) + def _execute_crud(self, stmt, mapper): + def exec_for_shard(shard_id): + conn = self._connection_from_session( + mapper=mapper, + shard_id=shard_id, + clause=stmt, + close_with_result=True) + result = conn.execute(stmt, self._params) + return result + + if self._shard_id is not None: + return exec_for_shard(self._shard_id) + else: + rowcount = 0 + results = [] + for shard_id in self.query_chooser(self): + result = exec_for_shard(shard_id) + rowcount += result.rowcount + results.append(result) + + return ShardedResult(results, rowcount) + def _identity_lookup( self, mapper, primary_key_identity, identity_token=None, lazy_loaded_from=None, **kw): @@ -123,6 +145,29 @@ class ShardedQuery(Query): primary_key_identity, _db_load_fn, identity_token=identity_token) +class ShardedResult(object): + """A value object that represents multiple :class:`.ResultProxy` objects. + + This is used by the :meth:`.ShardedQuery._execute_crud` hook to return + an object that takes the place of the single :class:`.ResultProxy`. + + Attribute include ``result_proxies``, which is a sequence of the + actual :class:`.ResultProxy` objects, as well as ``aggregate_rowcount`` + or ``rowcount``, which is the sum of all the individual rowcount values. + + .. versionadded:: 1.3 + """ + + __slots__ = ('result_proxies', 'aggregate_rowcount',) + + def __init__(self, result_proxies, aggregate_rowcount): + self.result_proxies = result_proxies + self.aggregate_rowcount = aggregate_rowcount + + @property + def rowcount(self): + return self.aggregate_rowcount + class ShardedSession(Session): def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, query_cls=ShardedQuery, **kwargs): diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 73d9ef3bb..3c59f61d7 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -219,6 +219,10 @@ class AppenderMixin(object): mapper = object_mapper(instance) prop = mapper._props[self.attr.key] + + if prop.secondary is not None: + self._set_select_from([prop.secondary], False) + self._criterion = prop._with_parent( instance, alias_secondary=False) @@ -284,6 +288,7 @@ class AppenderMixin(object): query = sess.query(self.attr.target_mapper) query._criterion = self._criterion + query._from_obj = self._from_obj query._order_by = self._order_by return query diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 95e26d83c..afa3b50b9 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1337,9 +1337,7 @@ class BulkUD(object): self._do_post() def _execute_stmt(self, stmt): - self.result = self.query.session.execute( - stmt, params=self.query._params, - mapper=self.mapper) + self.result = self.query._execute_crud(stmt, self.mapper) self.rowcount = self.result.rowcount @util.dependencies("sqlalchemy.orm.query") diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 7e7c93527..bfddb5cfe 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3016,6 +3016,12 @@ class Query(object): result = conn.execute(querycontext.statement, self._params) return loading.instances(querycontext.query, result, querycontext) + def _execute_crud(self, stmt, mapper): + conn = self._connection_from_session( + mapper=mapper, clause=stmt, close_with_result=True) + + return conn.execute(stmt, self._params) + def _get_bind_args(self, querycontext, fn, **kw): return fn( mapper=self._bind_mapper(), @@ -4014,13 +4020,17 @@ class _BundleEntity(_QueryEntity): if isinstance(expr, Bundle): _BundleEntity(self, expr) else: - _ColumnEntity(self, expr, namespace=self) + _ColumnEntity(self, expr) self.supports_single_entity = self.bundle.single_entity @property def mapper(self): - return self.entity_zero.mapper + ezero = self.entity_zero + if ezero is not None: + return ezero.mapper + else: + return None @property def entities(self): diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 27a75d45c..818f1c0ae 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -116,7 +116,8 @@ class RelationshipProperty(StrategizedProperty): bake_queries=True, _local_remote_pairs=None, query_class=None, - info=None): + info=None, + omit_join=None): """Provide a relationship between two mapped classes. This corresponds to a parent-child or associative table relationship. @@ -816,6 +817,13 @@ class RelationshipProperty(StrategizedProperty): the full set of related objects, to prevent modifications of the collection from resulting in persistence operations. + :param omit_join: + Allows manual control over the "selectin" automatic join + optimization. Set to ``False`` to disable the "omit join" feature + added in SQLAlchemy 1.3. + + .. versionadded:: 1.3 + """ super(RelationshipProperty, self).__init__() @@ -843,6 +851,7 @@ class RelationshipProperty(StrategizedProperty): self.doc = doc self.active_history = active_history self.join_depth = join_depth + self.omit_join = omit_join self.local_remote_pairs = _local_remote_pairs self.extension = extension self.bake_queries = bake_queries diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d7597d3b2..b9abf0647 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1837,8 +1837,8 @@ class JoinedLoader(AbstractRelationshipLoader): @properties.RelationshipProperty.strategy_for(lazy="selectin") class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): __slots__ = ( - 'join_depth', '_parent_alias', '_in_expr', '_parent_pk_cols', - '_zero_idx', '_bakery' + 'join_depth', 'omit_join', '_parent_alias', '_in_expr', + '_pk_cols', '_zero_idx', '_bakery' ) _chunksize = 500 @@ -1846,9 +1846,46 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): def __init__(self, parent, strategy_key): super(SelectInLoader, self).__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth + + if self.parent_property.omit_join is not None: + self.omit_join = self.parent_property.omit_join + else: + lazyloader = self.parent_property._get_strategy( + (("lazy", "select"),)) + self.omit_join = self.parent._get_clause[0].compare( + lazyloader._rev_lazywhere, + use_proxies=True, + equivalents=self.parent._equivalent_columns + ) + if self.omit_join: + self._init_for_omit_join() + else: + self._init_for_join() + + def _init_for_omit_join(self): + pk_to_fk = dict( + self.parent_property._join_condition.local_remote_pairs + ) + pk_to_fk.update( + (equiv, pk_to_fk[k]) + for k in list(pk_to_fk) + for equiv in self.parent._equivalent_columns.get(k, ()) + ) + + self._pk_cols = fk_cols = [ + pk_to_fk[col] + for col in self.parent.primary_key if col in pk_to_fk] + if len(fk_cols) > 1: + self._in_expr = sql.tuple_(*fk_cols) + self._zero_idx = False + else: + self._in_expr = fk_cols[0] + self._zero_idx = True + + def _init_for_join(self): self._parent_alias = aliased(self.parent.class_) pa_insp = inspect(self._parent_alias) - self._parent_pk_cols = pk_cols = [ + self._pk_cols = pk_cols = [ pa_insp._adapt_element(col) for col in self.parent.primary_key] if len(pk_cols) > 1: self._in_expr = sql.tuple_(*pk_cols) @@ -1922,8 +1959,24 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): for state, overwrite in states ] - pk_cols = self._parent_pk_cols - pa = self._parent_alias + pk_cols = self._pk_cols + in_expr = self._in_expr + + if self.omit_join: + # in "omit join" mode, the primary key column and the + # "in" expression are in terms of the related entity. So + # if the related entity is polymorphic or otherwise aliased, + # we need to adapt our "_pk_cols" and "_in_expr" to that + # entity. in non-"omit join" mode, these are against the + # parent entity and do not need adaption. + insp = inspect(effective_entity) + if insp.is_aliased_class: + pk_cols = [ + insp._adapt_element(col) + for col in pk_cols + ] + in_expr = insp._adapt_element(in_expr) + pk_cols = [insp._adapt_element(col) for col in pk_cols] q = self._bakery( lambda session: session.query( @@ -1931,15 +1984,30 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): ), self ) + if self.omit_join: + # the Bundle we have in the "omit_join" case is against raw, non + # annotated columns, so to ensure the Query knows its primary + # entity, we add it explictly. If we made the Bundle against + # annotated columns, we hit a performance issue in this specific + # case, which is detailed in issue #4347. + q.add_criteria(lambda q: q.select_from(effective_entity)) + else: + # in the non-omit_join case, the Bundle is against the annotated/ + # mapped column of the parent entity, but the #4347 issue does not + # occur in this case. + pa = self._parent_alias + q.add_criteria( + lambda q: q.select_from(pa).join( + getattr(pa, self.parent_property.key).of_type( + effective_entity) + ) + ) + q.add_criteria( - lambda q: q.select_from(pa).join( - getattr(pa, - self.parent_property.key).of_type(effective_entity)). - filter( - self._in_expr.in_( - sql.bindparam('primary_keys', expanding=True)) - ).order_by(*pk_cols) - ) + lambda q: q.filter( + in_expr.in_( + sql.bindparam("primary_keys", expanding=True)) + ).order_by(*pk_cols)) orig_query = context.query @@ -1954,23 +2022,30 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) if self.parent_property.order_by: - def _setup_outermost_orderby(q): - # imitate the same method that - # subquery eager loading does it, looking for the - # adapted "secondary" table - eagerjoin = q._from_obj[0] - eager_order_by = \ - eagerjoin._target_adapter.\ - copy_and_process( - util.to_list( - self.parent_property.order_by + if self.omit_join: + eager_order_by = self.parent_property.order_by + if insp.is_aliased_class: + eager_order_by = [ + insp._adapt_element(elem) for elem in + eager_order_by + ] + q.add_criteria( + lambda q: q.order_by(*eager_order_by) + ) + else: + def _setup_outermost_orderby(q): + # imitate the same method that subquery eager loading uses, + # looking for the adapted "secondary" table + eagerjoin = q._from_obj[0] + eager_order_by = \ + eagerjoin._target_adapter.\ + copy_and_process( + util.to_list(self.parent_property.order_by) ) - ) - return q.order_by(*eager_order_by) - - q.add_criteria( - _setup_outermost_orderby - ) + return q.order_by(*eager_order_by) + q.add_criteria( + _setup_outermost_orderby + ) uselist = self.uselist _empty_result = () if uselist else None diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2f68b7e2e..27ee4afc6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1056,7 +1056,7 @@ class SQLCompiler(Compiled): self._emit_empty_in_warning() return self.process(binary.left == binary.left) - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, element_types): raise NotImplementedError( "Dialect '%s' does not support empty set expression." % self.dialect.name diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 5d02f65a1..8149f9731 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -15,7 +15,8 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \ Null, _const_expr, _clause_element_as_expr, \ ClauseList, ColumnElement, TextClause, UnaryExpression, \ collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \ - Slice, Visitable, _literal_as_binds, CollectionAggregate + Slice, Visitable, _literal_as_binds, CollectionAggregate, \ + Tuple from .selectable import SelectBase, Alias, Selectable, ScalarSelect @@ -145,6 +146,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): elif isinstance(seq_or_selectable, ClauseElement): if isinstance(seq_or_selectable, BindParameter) and \ seq_or_selectable.expanding: + + if isinstance(expr, Tuple): + seq_or_selectable = ( + seq_or_selectable._with_expanding_in_types( + [elem.type for elem in expr] + ) + ) + return _boolean_compare( expr, op, seq_or_selectable, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index dd16b6862..de3b7992a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -865,6 +865,7 @@ class BindParameter(ColumnElement): __visit_name__ = 'bindparam' _is_crud = False + _expanding_in_types = () def __init__(self, key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, @@ -1134,6 +1135,15 @@ class BindParameter(ColumnElement): else: self.type = type_ + def _with_expanding_in_types(self, types): + """Return a copy of this :class:`.BindParameter` in + the context of an expanding IN against a tuple. + + """ + cloned = self._clone() + cloned._expanding_in_types = types + return cloned + def _with_value(self, value): """Return a copy of this :class:`.BindParameter` with the given value set. diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 7404befb8..d17e30edf 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -59,6 +59,12 @@ class ConnectionKiller(object): # not sure if this should be if pypy/jython only. # note that firebird/fdb definitely needs this though for conn, rec in list(self.conns): + if rec.connection is None: + # this is a hint that the connection is closed, which + # is causing segfaults on mysqlclient due to + # https://github.com/PyMySQL/mysqlclient-python/issues/270; + # try to work around here + continue self._safe(conn.rollback) def _stop_test_ctx(self): diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 78b34f496..73ce02492 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -402,6 +402,34 @@ class ExpandingBoundInTest(fixtures.TablesTest): params={"q": [], "p": []}, ) + @testing.requires.tuple_in + def test_empty_heterogeneous_tuples(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.z).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [], + params={"q": []}, + ) + + @testing.requires.tuple_in + def test_empty_homogeneous_tuples(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.y).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [], + params={"q": []}, + ) + def test_bound_in_scalar(self): table = self.tables.some_table @@ -428,6 +456,20 @@ class ExpandingBoundInTest(fixtures.TablesTest): params={"q": [(2, 3), (3, 4), (4, 5)]}, ) + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.z).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [(2, ), (3, ), (4, )], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + def test_empty_set_against_integer(self): table = self.tables.some_table |
