diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 44 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 129 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 28 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/util.py | 33 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 15 |
10 files changed, 222 insertions, 57 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 1ca8f4e64..c7e09b164 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -592,6 +592,7 @@ class SQLiteDialect(default.DefaultDialect): supports_empty_insert = False supports_cast = True supports_multivalues_insert = True + supports_right_nested_joins = False default_paramstyle = 'qmark' execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 91869ab75..2ad7002c4 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -49,6 +49,8 @@ class DefaultDialect(interfaces.Dialect): postfetch_lastrowid = True implicit_returning = False + supports_right_nested_joins = True + supports_native_enum = False supports_native_boolean = False @@ -106,6 +108,7 @@ class DefaultDialect(interfaces.Dialect): def __init__(self, convert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, implicit_returning=None, + supports_right_nested_joins=None, case_sensitive=True, label_length=None, **kwargs): @@ -130,6 +133,8 @@ class DefaultDialect(interfaces.Dialect): self.positional = self.paramstyle in ('qmark', 'format', 'numeric') self.identifier_preparer = self.preparer(self) self.type_compiler = self.type_compiler(self) + if supports_right_nested_joins is not None: + self.supports_right_nested_joins = supports_right_nested_joins self.case_sensitive = case_sensitive diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index beae7aba0..39ed8d8bf 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -447,6 +447,8 @@ class Query(object): statement if self._params: stmt = stmt.params(self._params) + + # TODO: there's no tests covering effects of # the annotation not being there return stmt._annotate({'no_replacement_traverse': True}) @@ -1795,6 +1797,7 @@ class Query(object): right_entity, onclause, outerjoin, create_aliases, prop) + def _join_left_to_right(self, left, right, onclause, outerjoin, create_aliases, prop): """append a JOIN to the query's from clause.""" @@ -1814,10 +1817,21 @@ class Query(object): "are the same entity" % (left, right)) + l_info = inspect(left) + r_info = inspect(right) + + overlap = not create_aliases and \ + sql_util.selectables_overlap(l_info.selectable, + r_info.selectable) + if overlap and l_info.selectable is r_info.selectable: + raise sa_exc.InvalidRequestError( + "Can't join table/selectable '%s' to itself" % + l_info.selectable) + right, onclause = self._prepare_right_side( - right, onclause, + r_info, right, onclause, create_aliases, - prop) + prop, overlap) # if joining on a MapperProperty path, # track the path to prevent redundant joins @@ -1829,10 +1843,11 @@ class Query(object): else: self._joinpoint = {'_joinpoint_entity': right} - self._join_to_left(left, right, onclause, outerjoin) + self._join_to_left(l_info, left, right, onclause, outerjoin) - def _prepare_right_side(self, right, onclause, create_aliases, prop): - info = inspect(right) + def _prepare_right_side(self, r_info, right, onclause, create_aliases, + prop, overlap): + info = r_info right_mapper, right_selectable, right_is_aliased = \ getattr(info, 'mapper', None), \ @@ -1862,19 +1877,23 @@ class Query(object): (right_selectable.description, right_mapper.mapped_table.description)) - if not isinstance(right_selectable, expression.Alias): + if isinstance(right_selectable, expression.SelectBase): + # TODO: this isn't even covered now! right_selectable = right_selectable.alias() + need_adapter = True right = aliased(right_mapper, right_selectable) - need_adapter = True aliased_entity = right_mapper and \ not right_is_aliased and \ ( - right_mapper.with_polymorphic or isinstance( - right_mapper.mapped_table, - expression.Join) + right_mapper._with_polymorphic_selectable, + expression.Alias) + or + overlap # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest ) if not need_adapter and (create_aliases or aliased_entity): @@ -1910,8 +1929,8 @@ class Query(object): return right, onclause - def _join_to_left(self, left, right, onclause, outerjoin): - info = inspect(left) + def _join_to_left(self, l_info, left, right, onclause, outerjoin): + info = l_info left_mapper = getattr(info, 'mapper', None) left_selectable = info.selectable @@ -1946,7 +1965,6 @@ class Query(object): clause = left_selectable assert clause is not None - try: clause = orm_join(clause, right, onclause, isouter=outerjoin) except sa_exc.ArgumentError as ae: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 95fa28613..33377d3ec 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -17,7 +17,7 @@ from .. import sql, util, exc as sa_exc, schema from ..sql.util import ( ClauseAdapter, join_condition, _shallow_annotate, visit_binary_product, - _deep_deannotate, find_tables + _deep_deannotate, find_tables, selectables_overlap ) from ..sql import operators, expression, visitors from .interfaces import MANYTOMANY, MANYTOONE, ONETOMANY @@ -404,11 +404,7 @@ class JoinCondition(object): def _tables_overlap(self): """Return True if parent/child tables have some overlap.""" - return bool( - set(find_tables(self.parent_selectable)).intersection( - find_tables(self.child_selectable) - ) - ) + return selectables_overlap(self.parent_selectable, self.child_selectable) def _annotate_remote(self): """Annotate the primaryjoin and secondaryjoin diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index bd8228f2c..c21e7eace 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -493,6 +493,7 @@ class AliasedClass(object): """ def __init__(self, cls, alias=None, name=None, + flat=True, adapt_on_names=False, # TODO: None for default here? with_polymorphic_mappers=(), @@ -501,7 +502,7 @@ class AliasedClass(object): use_mapper_path=False): mapper = _class_to_mapper(cls) if alias is None: - alias = mapper._with_polymorphic_selectable.alias(name=name) + alias = mapper._with_polymorphic_selectable.alias(name=name, flat=flat) self._aliased_insp = AliasedInsp( self, mapper, @@ -837,7 +838,7 @@ def with_polymorphic(base, classes, selectable=False, _with_polymorphic_args(classes, selectable, innerjoin=innerjoin) if aliased: - selectable = selectable.alias() + selectable = selectable.alias(flat=True) return AliasedClass(base, selectable, with_polymorphic_mappers=mappers, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 73b094053..dd2a6e08c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1113,23 +1113,115 @@ class SQLCompiler(engine.Compiled): def get_crud_hint_text(self, table, text): return None + def _transform_select_for_nested_joins(self, select): + """Rewrite any "a JOIN (b JOIN c)" expression as + "a JOIN (select * from b JOIN c) AS anon", to support + databases that can't parse a parenthesized join correctly + (i.e. sqlite the main one). + + """ + cloned = {} + column_translate = [{}] + + # TODO: should we be using isinstance() for this, + # as this whole system won't work for custom Join/Select + # subclasses where compilation routines + # call down to compiler.visit_join(), compiler.visit_select() + join_name = sql.Join.__visit_name__ + select_name = sql.Select.__visit_name__ + + def visit(element, **kw): + if element in column_translate[-1]: + return column_translate[-1][element] + + elif element in cloned: + return cloned[element] + + newelem = cloned[element] = element._clone() + + if newelem.__visit_name__ is join_name and \ + isinstance(newelem.right, sql.FromGrouping): + + newelem._reset_exported() + newelem.left = visit(newelem.left, **kw) + + right = visit(newelem.right, **kw) + + selectable = sql.select( + [right.element], + use_labels=True).alias() + + for c in selectable.c: + c._label = c._key_label = c.name + translate_dict = dict( + zip(right.element.c, selectable.c) + ) + translate_dict[right.element.left] = selectable + translate_dict[right.element.right] = selectable + + # propagate translations that we've gained + # from nested visit(newelem.right) outwards + # to the enclosing select here. this happens + # only when we have more than one level of right + # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))" + for k, v in list(column_translate[-1].items()): + if v in translate_dict: + # remarkably, no current ORM tests (May 2013) + # hit this condition, only test_join_rewriting + # does. + column_translate[-1][k] = translate_dict[v] + + column_translate[-1].update(translate_dict) + + newelem.right = selectable + newelem.onclause = visit(newelem.onclause, **kw) + elif newelem.__visit_name__ is select_name: + column_translate.append({}) + newelem._copy_internals(clone=visit, **kw) + del column_translate[-1] + else: + newelem._copy_internals(clone=visit, **kw) + + return newelem + + return visit(select) + + def _transform_result_map_for_nested_joins(self, select, transformed_select): + inner_col = dict((c._key_label, c) for + c in transformed_select.inner_columns) + d = dict( + (inner_col[c._key_label], c) + for c in select.inner_columns + ) + for key, (name, objs, typ) in list(self.result_map.items()): + objs = tuple([d.get(col, col) for col in objs]) + self.result_map[key] = (name, objs, typ) + def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, compound_index=0, force_result_map=False, - positional_names=None, **kwargs): - entry = self.stack and self.stack[-1] or {} - - existingfroms = entry.get('from', None) - - froms = select._get_display_froms(existingfroms, asfrom=asfrom) - - correlate_froms = set(sql._from_objects(*froms)) + positional_names=None, + nested_join_translation=False, **kwargs): + + needs_nested_translation = \ + select.use_labels and \ + not nested_join_translation and \ + not self.stack and \ + not self.dialect.supports_right_nested_joins + + if needs_nested_translation: + transformed_select = self._transform_select_for_nested_joins(select) + text = self.visit_select( + transformed_select, asfrom=asfrom, parens=parens, + iswrapper=iswrapper, fromhints=fromhints, + compound_index=compound_index, + force_result_map=force_result_map, + positional_names=positional_names, + nested_join_translation=True, **kwargs + ) - # TODO: might want to propagate existing froms for - # select(select(select)) where innermost select should correlate - # to outermost if existingfroms: correlate_froms = - # correlate_froms.union(existingfroms) + entry = self.stack and self.stack[-1] or {} populate_result_map = force_result_map or ( compound_index == 0 and ( @@ -1138,6 +1230,19 @@ class SQLCompiler(engine.Compiled): ) ) + if needs_nested_translation: + if populate_result_map: + self._transform_result_map_for_nested_joins( + select, transformed_select) + return text + + existingfroms = entry.get('from', None) + + froms = select._get_display_froms(existingfroms, asfrom=asfrom) + + correlate_froms = set(sql._from_objects(*froms)) + + self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper}) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 6dc134d98..f0c6134e5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -795,7 +795,7 @@ def intersect_all(*selects, **kwargs): return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs) -def alias(selectable, name=None): +def alias(selectable, name=None, flat=False): """Return an :class:`.Alias` object. An :class:`.Alias` represents any :class:`.FromClause` @@ -2636,7 +2636,7 @@ class FromClause(Selectable): return Join(self, right, onclause, True) - def alias(self, name=None): + def alias(self, name=None, flat=False): """return an alias of this :class:`.FromClause`. This is shorthand for calling:: @@ -3980,7 +3980,7 @@ class Join(FromClause): def bind(self): return self.left.bind or self.right.bind - def alias(self, name=None): + def alias(self, name=None, flat=False): """return an alias of this :class:`.Join`. Used against a :class:`.Join` object, @@ -4008,7 +4008,17 @@ class Join(FromClause): aliases. """ - return self.select(use_labels=True, correlate=False).alias(name) + if flat: + assert name is None, "Can't send name argument with flat" + left_a, right_a = self.left.alias(flat=True), \ + self.right.alias(flat=True) + adapter = sqlutil.ClauseAdapter(left_a).\ + chain(sqlutil.ClauseAdapter(right_a)) + + return left_a.join(right_a, + adapter.traverse(self.onclause), isouter=self.isouter) + else: + return self.select(use_labels=True, correlate=False).alias(name) @property def _hide_froms(self): @@ -4138,7 +4148,7 @@ class CTE(Alias): self._restates = _restates super(CTE, self).__init__(selectable, name=name) - def alias(self, name=None): + def alias(self, name=None, flat=False): return CTE( self.original, name=name, @@ -4221,10 +4231,10 @@ class FromGrouping(FromClause): @property def foreign_keys(self): - # this could be - # self.element.foreign_keys - # see SelectableTest.test_join_condition - return set() + return self.element.foreign_keys + + def is_derived_from(self, element): + return self.element.is_derived_from(element) @property def _hide_froms(self): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 91740dc16..6f4d27e1b 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -200,15 +200,28 @@ def clause_is_present(clause, search): """ - stack = [search] - while stack: - elem = stack.pop() + for elem in surface_selectables(search): if clause == elem: # use == here so that Annotated's compare return True - elif isinstance(elem, expression.Join): + else: + return False + +def surface_selectables(clause): + stack = [clause] + while stack: + elem = stack.pop() + yield elem + if isinstance(elem, expression.Join): stack.extend((elem.left, elem.right)) - return False +def selectables_overlap(left, right): + """Return True if left/right have some overlapping selectable""" + + return bool( + set(surface_selectables(left)).intersection( + surface_selectables(right) + ) + ) def bind_values(clause): """Return an ordered list of "bound" values in the given clause. @@ -797,8 +810,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): def __init__(self, selectable, equivalents=None, include=None, exclude=None, include_fn=None, exclude_fn=None, - adapt_on_names=False): + adapt_on_names=False, + traverse_options=None): self.__traverse_options__ = {'stop_on': [selectable]} + if traverse_options: + self.__traverse_options__.update(traverse_options) self.selectable = selectable if include: assert not include_fn @@ -829,10 +845,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): newcol = self.selectable.c.get(col.name) return newcol + magic_flag = False def replace(self, col): - if isinstance(col, expression.FromClause) and \ + if not self.magic_flag and isinstance(col, expression.FromClause) and \ self.selectable.is_derived_from(col): - return self.selectable + return self.selectable elif not isinstance(col, expression.ColumnElement): return None elif self.include_fn and not self.include_fn(col): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 62f46ab64..c5a45ffd4 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -30,6 +30,7 @@ import operator __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor', 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate', 'iterate_depthfirst', 'traverse_using', 'traverse', + 'traverse_depthfirst', 'cloned_traverse', 'replacement_traverse'] @@ -255,7 +256,11 @@ def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors.""" - cloned = util.column_dict() + + if "cloned" in opts: + cloned = opts['cloned'] + else: + cloned = util.column_dict() stop_on = util.column_set(opts.get('stop_on', [])) def clone(elem): @@ -281,10 +286,12 @@ def replacement_traverse(obj, opts, replace): cloned = util.column_dict() stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])]) + unconditional = opts.get('unconditional', False) def clone(elem, **kw): if id(elem) in stop_on or \ - 'no_replacement_traverse' in elem._annotations: + (not unconditional + and 'no_replacement_traverse' in elem._annotations): return elem else: newelem = replace(elem) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index c04153961..96a8bc023 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -184,15 +184,20 @@ class AssertsCompiledSQL(object): allow_dialect_select=False): if use_default_dialect: dialect = default.DefaultDialect() - elif dialect == None and not allow_dialect_select: - dialect = getattr(self, '__dialect__', None) - if dialect == 'default': - dialect = default.DefaultDialect() - elif dialect is None: + elif allow_dialect_select: + dialect = None + else: + if dialect is None: + dialect = getattr(self, '__dialect__', None) + + if dialect is None: dialect = config.db.dialect + elif dialect == 'default': + dialect = default.DefaultDialect() elif isinstance(dialect, util.string_types): dialect = create_engine("%s://" % dialect).dialect + kw = {} if params is not None: kw['column_keys'] = list(params) |
