summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py1
-rw-r--r--lib/sqlalchemy/engine/default.py2
-rw-r--r--lib/sqlalchemy/orm/query.py26
-rw-r--r--lib/sqlalchemy/sql/compiler.py14
-rw-r--r--lib/sqlalchemy/testing/assertions.py13
5 files changed, 44 insertions, 12 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 1ca8f4e64..c7e09b164 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -592,6 +592,7 @@ class SQLiteDialect(default.DefaultDialect):
supports_empty_insert = False
supports_cast = True
supports_multivalues_insert = True
+ supports_right_nested_joins = False
default_paramstyle = 'qmark'
execution_ctx_cls = SQLiteExecutionContext
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 85d11ff36..94f59d388 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -49,6 +49,8 @@ class DefaultDialect(interfaces.Dialect):
postfetch_lastrowid = True
implicit_returning = False
+ supports_right_nested_joins = True
+
supports_native_enum = False
supports_native_boolean = False
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 66d208a74..80fee4a09 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1797,6 +1797,15 @@ class Query(object):
right_entity, onclause,
outerjoin, create_aliases, prop)
+ def _tables_overlap(self, left, right):
+ """Return True if parent/child tables have some overlap."""
+
+ return bool(
+ set(sql_util.find_tables(left)).intersection(
+ sql_util.find_tables(right)
+ )
+ )
+
def _join_left_to_right(self, left, right,
onclause, outerjoin, create_aliases, prop):
"""append a JOIN to the query's from clause."""
@@ -1816,10 +1825,16 @@ class Query(object):
"are the same entity" %
(left, right))
+ # TODO: get the l_info, r_info passed into
+ # the methods so inspect() doesnt need to be called again
+ l_info = inspect(left)
+ r_info = inspect(right)
+ overlap = self._tables_overlap(l_info.selectable, r_info.selectable)
+
right, onclause = self._prepare_right_side(
right, onclause,
create_aliases,
- prop)
+ prop, overlap)
# if joining on a MapperProperty path,
# track the path to prevent redundant joins
@@ -1833,7 +1848,7 @@ class Query(object):
self._join_to_left(left, right, onclause, outerjoin)
- def _prepare_right_side(self, right, onclause, create_aliases, prop):
+ def _prepare_right_side(self, right, onclause, create_aliases, prop, overlap):
info = inspect(right)
right_mapper, right_selectable, right_is_aliased = \
@@ -1875,9 +1890,10 @@ class Query(object):
not right_is_aliased and \
(
right_mapper.with_polymorphic
- #isinstance(
- # right_mapper.mapped_table,
- # expression.Join)
+ or
+ overlap # test for overlap:
+ # orm/inheritance/relationships.py
+ # SelfReferentialM2MTest
)
if not need_adapter and (create_aliases or aliased_entity):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ff041d5e4..d5ba64938 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1079,7 +1079,10 @@ class SQLCompiler(engine.Compiled):
def _transform_select_for_nested_joins(self, select):
adapters = []
+ stop_on = []
+ # test for "unconditional" - any statement with
+ # no_replacement_traverse setup, i.e. query.statement, from_self(), etc.
traverse_options = {"cloned": {}, "unconditional": True}
def visit_join(elem):
@@ -1090,6 +1093,12 @@ class SQLCompiler(engine.Compiled):
while adapters:
adapt = adapters.pop(-1)
selectable = adapt.traverse(selectable)
+ #stop_on.append(selectable)
+
+ # test: see test_subquery_relations:
+ # CyclicalInheritingEagerTestTwo.test_integrate
+ stop_on.append(elem.left)
+
for c in selectable.c:
c._label = c._key_label = c.name
@@ -1097,6 +1106,7 @@ class SQLCompiler(engine.Compiled):
elem.right = selectable
adapter = sql_util.ClauseAdapter(selectable,
traverse_options=traverse_options)
+ adapter.__traverse_options__['stop_on'].extend(stop_on)
adapters.append(adapter)
select = visitors.cloned_traverse(select,
@@ -1119,7 +1129,9 @@ class SQLCompiler(engine.Compiled):
positional_names=None,
nested_join_translation=False, **kwargs):
- #nested_join_translation = True
+
+ if self.dialect.supports_right_nested_joins:
+ nested_join_translation = True
if not nested_join_translation:
transformed_select = self._transform_select_for_nested_joins(select)
text = self.visit_select(
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index c04153961..592467302 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -186,12 +186,13 @@ class AssertsCompiledSQL(object):
dialect = default.DefaultDialect()
elif dialect == None and not allow_dialect_select:
dialect = getattr(self, '__dialect__', None)
- if dialect == 'default':
- dialect = default.DefaultDialect()
- elif dialect is None:
- dialect = config.db.dialect
- elif isinstance(dialect, util.string_types):
- dialect = create_engine("%s://" % dialect).dialect
+
+ if dialect == 'default':
+ dialect = default.DefaultDialect()
+ elif dialect is None:
+ dialect = config.db.dialect
+ elif isinstance(dialect, util.string_types):
+ dialect = create_engine("%s://" % dialect).dialect
kw = {}
if params is not None: