diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/sqlite/base.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 13 |
5 files changed, 44 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 1ca8f4e64..c7e09b164 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -592,6 +592,7 @@ class SQLiteDialect(default.DefaultDialect): supports_empty_insert = False supports_cast = True supports_multivalues_insert = True + supports_right_nested_joins = False default_paramstyle = 'qmark' execution_ctx_cls = SQLiteExecutionContext diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 85d11ff36..94f59d388 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -49,6 +49,8 @@ class DefaultDialect(interfaces.Dialect): postfetch_lastrowid = True implicit_returning = False + supports_right_nested_joins = True + supports_native_enum = False supports_native_boolean = False diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 66d208a74..80fee4a09 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1797,6 +1797,15 @@ class Query(object): right_entity, onclause, outerjoin, create_aliases, prop) + def _tables_overlap(self, left, right): + """Return True if parent/child tables have some overlap.""" + + return bool( + set(sql_util.find_tables(left)).intersection( + sql_util.find_tables(right) + ) + ) + def _join_left_to_right(self, left, right, onclause, outerjoin, create_aliases, prop): """append a JOIN to the query's from clause.""" @@ -1816,10 +1825,16 @@ class Query(object): "are the same entity" % (left, right)) + # TODO: get the l_info, r_info passed into + # the methods so inspect() doesnt need to be called again + l_info = inspect(left) + r_info = inspect(right) + overlap = self._tables_overlap(l_info.selectable, r_info.selectable) + right, onclause = self._prepare_right_side( right, onclause, create_aliases, - prop) + prop, overlap) # if joining on a MapperProperty path, # track the path to prevent redundant joins @@ -1833,7 +1848,7 @@ class Query(object): self._join_to_left(left, right, onclause, outerjoin) - def _prepare_right_side(self, right, onclause, create_aliases, prop): + def _prepare_right_side(self, right, onclause, create_aliases, prop, overlap): info = inspect(right) right_mapper, right_selectable, right_is_aliased = \ @@ -1875,9 +1890,10 @@ class Query(object): not right_is_aliased and \ ( right_mapper.with_polymorphic - #isinstance( - # right_mapper.mapped_table, - # expression.Join) + or + overlap # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest ) if not need_adapter and (create_aliases or aliased_entity): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ff041d5e4..d5ba64938 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1079,7 +1079,10 @@ class SQLCompiler(engine.Compiled): def _transform_select_for_nested_joins(self, select): adapters = [] + stop_on = [] + # test for "unconditional" - any statement with + # no_replacement_traverse setup, i.e. query.statement, from_self(), etc. traverse_options = {"cloned": {}, "unconditional": True} def visit_join(elem): @@ -1090,6 +1093,12 @@ class SQLCompiler(engine.Compiled): while adapters: adapt = adapters.pop(-1) selectable = adapt.traverse(selectable) + #stop_on.append(selectable) + + # test: see test_subquery_relations: + # CyclicalInheritingEagerTestTwo.test_integrate + stop_on.append(elem.left) + for c in selectable.c: c._label = c._key_label = c.name @@ -1097,6 +1106,7 @@ class SQLCompiler(engine.Compiled): elem.right = selectable adapter = sql_util.ClauseAdapter(selectable, traverse_options=traverse_options) + adapter.__traverse_options__['stop_on'].extend(stop_on) adapters.append(adapter) select = visitors.cloned_traverse(select, @@ -1119,7 +1129,9 @@ class SQLCompiler(engine.Compiled): positional_names=None, nested_join_translation=False, **kwargs): - #nested_join_translation = True + + if self.dialect.supports_right_nested_joins: + nested_join_translation = True if not nested_join_translation: transformed_select = self._transform_select_for_nested_joins(select) text = self.visit_select( diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index c04153961..592467302 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -186,12 +186,13 @@ class AssertsCompiledSQL(object): dialect = default.DefaultDialect() elif dialect == None and not allow_dialect_select: dialect = getattr(self, '__dialect__', None) - if dialect == 'default': - dialect = default.DefaultDialect() - elif dialect is None: - dialect = config.db.dialect - elif isinstance(dialect, util.string_types): - dialect = create_engine("%s://" % dialect).dialect + + if dialect == 'default': + dialect = default.DefaultDialect() + elif dialect is None: + dialect = config.db.dialect + elif isinstance(dialect, util.string_types): + dialect = create_engine("%s://" % dialect).dialect kw = {} if params is not None: |
