summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py1
-rw-r--r--lib/sqlalchemy/engine/default.py5
-rw-r--r--lib/sqlalchemy/orm/query.py44
-rw-r--r--lib/sqlalchemy/orm/relationships.py8
-rw-r--r--lib/sqlalchemy/orm/util.py5
-rw-r--r--lib/sqlalchemy/sql/compiler.py129
-rw-r--r--lib/sqlalchemy/sql/expression.py28
-rw-r--r--lib/sqlalchemy/sql/util.py33
-rw-r--r--lib/sqlalchemy/sql/visitors.py11
-rw-r--r--lib/sqlalchemy/testing/assertions.py15
10 files changed, 222 insertions, 57 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index 1ca8f4e64..c7e09b164 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -592,6 +592,7 @@ class SQLiteDialect(default.DefaultDialect):
supports_empty_insert = False
supports_cast = True
supports_multivalues_insert = True
+ supports_right_nested_joins = False
default_paramstyle = 'qmark'
execution_ctx_cls = SQLiteExecutionContext
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 91869ab75..2ad7002c4 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -49,6 +49,8 @@ class DefaultDialect(interfaces.Dialect):
postfetch_lastrowid = True
implicit_returning = False
+ supports_right_nested_joins = True
+
supports_native_enum = False
supports_native_boolean = False
@@ -106,6 +108,7 @@ class DefaultDialect(interfaces.Dialect):
def __init__(self, convert_unicode=False,
encoding='utf-8', paramstyle=None, dbapi=None,
implicit_returning=None,
+ supports_right_nested_joins=None,
case_sensitive=True,
label_length=None, **kwargs):
@@ -130,6 +133,8 @@ class DefaultDialect(interfaces.Dialect):
self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
self.identifier_preparer = self.preparer(self)
self.type_compiler = self.type_compiler(self)
+ if supports_right_nested_joins is not None:
+ self.supports_right_nested_joins = supports_right_nested_joins
self.case_sensitive = case_sensitive
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index beae7aba0..39ed8d8bf 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -447,6 +447,8 @@ class Query(object):
statement
if self._params:
stmt = stmt.params(self._params)
+
+
# TODO: there's no tests covering effects of
# the annotation not being there
return stmt._annotate({'no_replacement_traverse': True})
@@ -1795,6 +1797,7 @@ class Query(object):
right_entity, onclause,
outerjoin, create_aliases, prop)
+
def _join_left_to_right(self, left, right,
onclause, outerjoin, create_aliases, prop):
"""append a JOIN to the query's from clause."""
@@ -1814,10 +1817,21 @@ class Query(object):
"are the same entity" %
(left, right))
+ l_info = inspect(left)
+ r_info = inspect(right)
+
+ overlap = not create_aliases and \
+ sql_util.selectables_overlap(l_info.selectable,
+ r_info.selectable)
+ if overlap and l_info.selectable is r_info.selectable:
+ raise sa_exc.InvalidRequestError(
+ "Can't join table/selectable '%s' to itself" %
+ l_info.selectable)
+
right, onclause = self._prepare_right_side(
- right, onclause,
+ r_info, right, onclause,
create_aliases,
- prop)
+ prop, overlap)
# if joining on a MapperProperty path,
# track the path to prevent redundant joins
@@ -1829,10 +1843,11 @@ class Query(object):
else:
self._joinpoint = {'_joinpoint_entity': right}
- self._join_to_left(left, right, onclause, outerjoin)
+ self._join_to_left(l_info, left, right, onclause, outerjoin)
- def _prepare_right_side(self, right, onclause, create_aliases, prop):
- info = inspect(right)
+ def _prepare_right_side(self, r_info, right, onclause, create_aliases,
+ prop, overlap):
+ info = r_info
right_mapper, right_selectable, right_is_aliased = \
getattr(info, 'mapper', None), \
@@ -1862,19 +1877,23 @@ class Query(object):
(right_selectable.description,
right_mapper.mapped_table.description))
- if not isinstance(right_selectable, expression.Alias):
+ if isinstance(right_selectable, expression.SelectBase):
+ # TODO: this isn't even covered now!
right_selectable = right_selectable.alias()
+ need_adapter = True
right = aliased(right_mapper, right_selectable)
- need_adapter = True
aliased_entity = right_mapper and \
not right_is_aliased and \
(
- right_mapper.with_polymorphic or
isinstance(
- right_mapper.mapped_table,
- expression.Join)
+ right_mapper._with_polymorphic_selectable,
+ expression.Alias)
+ or
+ overlap # test for overlap:
+ # orm/inheritance/relationships.py
+ # SelfReferentialM2MTest
)
if not need_adapter and (create_aliases or aliased_entity):
@@ -1910,8 +1929,8 @@ class Query(object):
return right, onclause
- def _join_to_left(self, left, right, onclause, outerjoin):
- info = inspect(left)
+ def _join_to_left(self, l_info, left, right, onclause, outerjoin):
+ info = l_info
left_mapper = getattr(info, 'mapper', None)
left_selectable = info.selectable
@@ -1946,7 +1965,6 @@ class Query(object):
clause = left_selectable
assert clause is not None
-
try:
clause = orm_join(clause, right, onclause, isouter=outerjoin)
except sa_exc.ArgumentError as ae:
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index 95fa28613..33377d3ec 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -17,7 +17,7 @@ from .. import sql, util, exc as sa_exc, schema
from ..sql.util import (
ClauseAdapter,
join_condition, _shallow_annotate, visit_binary_product,
- _deep_deannotate, find_tables
+ _deep_deannotate, find_tables, selectables_overlap
)
from ..sql import operators, expression, visitors
from .interfaces import MANYTOMANY, MANYTOONE, ONETOMANY
@@ -404,11 +404,7 @@ class JoinCondition(object):
def _tables_overlap(self):
"""Return True if parent/child tables have some overlap."""
- return bool(
- set(find_tables(self.parent_selectable)).intersection(
- find_tables(self.child_selectable)
- )
- )
+ return selectables_overlap(self.parent_selectable, self.child_selectable)
def _annotate_remote(self):
"""Annotate the primaryjoin and secondaryjoin
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index bd8228f2c..c21e7eace 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -493,6 +493,7 @@ class AliasedClass(object):
"""
def __init__(self, cls, alias=None,
name=None,
+ flat=True,
adapt_on_names=False,
# TODO: None for default here?
with_polymorphic_mappers=(),
@@ -501,7 +502,7 @@ class AliasedClass(object):
use_mapper_path=False):
mapper = _class_to_mapper(cls)
if alias is None:
- alias = mapper._with_polymorphic_selectable.alias(name=name)
+ alias = mapper._with_polymorphic_selectable.alias(name=name, flat=flat)
self._aliased_insp = AliasedInsp(
self,
mapper,
@@ -837,7 +838,7 @@ def with_polymorphic(base, classes, selectable=False,
_with_polymorphic_args(classes, selectable,
innerjoin=innerjoin)
if aliased:
- selectable = selectable.alias()
+ selectable = selectable.alias(flat=True)
return AliasedClass(base,
selectable,
with_polymorphic_mappers=mappers,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 73b094053..dd2a6e08c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -1113,23 +1113,115 @@ class SQLCompiler(engine.Compiled):
def get_crud_hint_text(self, table, text):
return None
+ def _transform_select_for_nested_joins(self, select):
+ """Rewrite any "a JOIN (b JOIN c)" expression as
+ "a JOIN (select * from b JOIN c) AS anon", to support
+ databases that can't parse a parenthesized join correctly
+ (i.e. sqlite the main one).
+
+ """
+ cloned = {}
+ column_translate = [{}]
+
+ # TODO: should we be using isinstance() for this,
+ # as this whole system won't work for custom Join/Select
+ # subclasses where compilation routines
+ # call down to compiler.visit_join(), compiler.visit_select()
+ join_name = sql.Join.__visit_name__
+ select_name = sql.Select.__visit_name__
+
+ def visit(element, **kw):
+ if element in column_translate[-1]:
+ return column_translate[-1][element]
+
+ elif element in cloned:
+ return cloned[element]
+
+ newelem = cloned[element] = element._clone()
+
+ if newelem.__visit_name__ is join_name and \
+ isinstance(newelem.right, sql.FromGrouping):
+
+ newelem._reset_exported()
+ newelem.left = visit(newelem.left, **kw)
+
+ right = visit(newelem.right, **kw)
+
+ selectable = sql.select(
+ [right.element],
+ use_labels=True).alias()
+
+ for c in selectable.c:
+ c._label = c._key_label = c.name
+ translate_dict = dict(
+ zip(right.element.c, selectable.c)
+ )
+ translate_dict[right.element.left] = selectable
+ translate_dict[right.element.right] = selectable
+
+ # propagate translations that we've gained
+ # from nested visit(newelem.right) outwards
+ # to the enclosing select here. this happens
+ # only when we have more than one level of right
+ # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))"
+ for k, v in list(column_translate[-1].items()):
+ if v in translate_dict:
+ # remarkably, no current ORM tests (May 2013)
+ # hit this condition, only test_join_rewriting
+ # does.
+ column_translate[-1][k] = translate_dict[v]
+
+ column_translate[-1].update(translate_dict)
+
+ newelem.right = selectable
+ newelem.onclause = visit(newelem.onclause, **kw)
+ elif newelem.__visit_name__ is select_name:
+ column_translate.append({})
+ newelem._copy_internals(clone=visit, **kw)
+ del column_translate[-1]
+ else:
+ newelem._copy_internals(clone=visit, **kw)
+
+ return newelem
+
+ return visit(select)
+
+ def _transform_result_map_for_nested_joins(self, select, transformed_select):
+ inner_col = dict((c._key_label, c) for
+ c in transformed_select.inner_columns)
+ d = dict(
+ (inner_col[c._key_label], c)
+ for c in select.inner_columns
+ )
+ for key, (name, objs, typ) in list(self.result_map.items()):
+ objs = tuple([d.get(col, col) for col in objs])
+ self.result_map[key] = (name, objs, typ)
+
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
compound_index=0,
force_result_map=False,
- positional_names=None, **kwargs):
- entry = self.stack and self.stack[-1] or {}
-
- existingfroms = entry.get('from', None)
-
- froms = select._get_display_froms(existingfroms, asfrom=asfrom)
-
- correlate_froms = set(sql._from_objects(*froms))
+ positional_names=None,
+ nested_join_translation=False, **kwargs):
+
+ needs_nested_translation = \
+ select.use_labels and \
+ not nested_join_translation and \
+ not self.stack and \
+ not self.dialect.supports_right_nested_joins
+
+ if needs_nested_translation:
+ transformed_select = self._transform_select_for_nested_joins(select)
+ text = self.visit_select(
+ transformed_select, asfrom=asfrom, parens=parens,
+ iswrapper=iswrapper, fromhints=fromhints,
+ compound_index=compound_index,
+ force_result_map=force_result_map,
+ positional_names=positional_names,
+ nested_join_translation=True, **kwargs
+ )
- # TODO: might want to propagate existing froms for
- # select(select(select)) where innermost select should correlate
- # to outermost if existingfroms: correlate_froms =
- # correlate_froms.union(existingfroms)
+ entry = self.stack and self.stack[-1] or {}
populate_result_map = force_result_map or (
compound_index == 0 and (
@@ -1138,6 +1230,19 @@ class SQLCompiler(engine.Compiled):
)
)
+ if needs_nested_translation:
+ if populate_result_map:
+ self._transform_result_map_for_nested_joins(
+ select, transformed_select)
+ return text
+
+ existingfroms = entry.get('from', None)
+
+ froms = select._get_display_froms(existingfroms, asfrom=asfrom)
+
+ correlate_froms = set(sql._from_objects(*froms))
+
+
self.stack.append({'from': correlate_froms,
'iswrapper': iswrapper})
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 6dc134d98..f0c6134e5 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -795,7 +795,7 @@ def intersect_all(*selects, **kwargs):
return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
-def alias(selectable, name=None):
+def alias(selectable, name=None, flat=False):
"""Return an :class:`.Alias` object.
An :class:`.Alias` represents any :class:`.FromClause`
@@ -2636,7 +2636,7 @@ class FromClause(Selectable):
return Join(self, right, onclause, True)
- def alias(self, name=None):
+ def alias(self, name=None, flat=False):
"""return an alias of this :class:`.FromClause`.
This is shorthand for calling::
@@ -3980,7 +3980,7 @@ class Join(FromClause):
def bind(self):
return self.left.bind or self.right.bind
- def alias(self, name=None):
+ def alias(self, name=None, flat=False):
"""return an alias of this :class:`.Join`.
Used against a :class:`.Join` object,
@@ -4008,7 +4008,17 @@ class Join(FromClause):
aliases.
"""
- return self.select(use_labels=True, correlate=False).alias(name)
+ if flat:
+ assert name is None, "Can't send name argument with flat"
+ left_a, right_a = self.left.alias(flat=True), \
+ self.right.alias(flat=True)
+ adapter = sqlutil.ClauseAdapter(left_a).\
+ chain(sqlutil.ClauseAdapter(right_a))
+
+ return left_a.join(right_a,
+ adapter.traverse(self.onclause), isouter=self.isouter)
+ else:
+ return self.select(use_labels=True, correlate=False).alias(name)
@property
def _hide_froms(self):
@@ -4138,7 +4148,7 @@ class CTE(Alias):
self._restates = _restates
super(CTE, self).__init__(selectable, name=name)
- def alias(self, name=None):
+ def alias(self, name=None, flat=False):
return CTE(
self.original,
name=name,
@@ -4221,10 +4231,10 @@ class FromGrouping(FromClause):
@property
def foreign_keys(self):
- # this could be
- # self.element.foreign_keys
- # see SelectableTest.test_join_condition
- return set()
+ return self.element.foreign_keys
+
+ def is_derived_from(self, element):
+ return self.element.is_derived_from(element)
@property
def _hide_froms(self):
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 91740dc16..6f4d27e1b 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -200,15 +200,28 @@ def clause_is_present(clause, search):
"""
- stack = [search]
- while stack:
- elem = stack.pop()
+ for elem in surface_selectables(search):
if clause == elem: # use == here so that Annotated's compare
return True
- elif isinstance(elem, expression.Join):
+ else:
+ return False
+
+def surface_selectables(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ yield elem
+ if isinstance(elem, expression.Join):
stack.extend((elem.left, elem.right))
- return False
+def selectables_overlap(left, right):
+ """Return True if left/right have some overlapping selectable"""
+
+ return bool(
+ set(surface_selectables(left)).intersection(
+ surface_selectables(right)
+ )
+ )
def bind_values(clause):
"""Return an ordered list of "bound" values in the given clause.
@@ -797,8 +810,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
def __init__(self, selectable, equivalents=None,
include=None, exclude=None,
include_fn=None, exclude_fn=None,
- adapt_on_names=False):
+ adapt_on_names=False,
+ traverse_options=None):
self.__traverse_options__ = {'stop_on': [selectable]}
+ if traverse_options:
+ self.__traverse_options__.update(traverse_options)
self.selectable = selectable
if include:
assert not include_fn
@@ -829,10 +845,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
newcol = self.selectable.c.get(col.name)
return newcol
+ magic_flag = False
def replace(self, col):
- if isinstance(col, expression.FromClause) and \
+ if not self.magic_flag and isinstance(col, expression.FromClause) and \
self.selectable.is_derived_from(col):
- return self.selectable
+ return self.selectable
elif not isinstance(col, expression.ColumnElement):
return None
elif self.include_fn and not self.include_fn(col):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 62f46ab64..c5a45ffd4 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -30,6 +30,7 @@ import operator
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
'iterate_depthfirst', 'traverse_using', 'traverse',
+ 'traverse_depthfirst',
'cloned_traverse', 'replacement_traverse']
@@ -255,7 +256,11 @@ def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing
modifications by visitors."""
- cloned = util.column_dict()
+
+ if "cloned" in opts:
+ cloned = opts['cloned']
+ else:
+ cloned = util.column_dict()
stop_on = util.column_set(opts.get('stop_on', []))
def clone(elem):
@@ -281,10 +286,12 @@ def replacement_traverse(obj, opts, replace):
cloned = util.column_dict()
stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])])
+ unconditional = opts.get('unconditional', False)
def clone(elem, **kw):
if id(elem) in stop_on or \
- 'no_replacement_traverse' in elem._annotations:
+ (not unconditional
+ and 'no_replacement_traverse' in elem._annotations):
return elem
else:
newelem = replace(elem)
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index c04153961..96a8bc023 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -184,15 +184,20 @@ class AssertsCompiledSQL(object):
allow_dialect_select=False):
if use_default_dialect:
dialect = default.DefaultDialect()
- elif dialect == None and not allow_dialect_select:
- dialect = getattr(self, '__dialect__', None)
- if dialect == 'default':
- dialect = default.DefaultDialect()
- elif dialect is None:
+ elif allow_dialect_select:
+ dialect = None
+ else:
+ if dialect is None:
+ dialect = getattr(self, '__dialect__', None)
+
+ if dialect is None:
dialect = config.db.dialect
+ elif dialect == 'default':
+ dialect = default.DefaultDialect()
elif isinstance(dialect, util.string_types):
dialect = create_engine("%s://" % dialect).dialect
+
kw = {}
if params is not None:
kw['column_keys'] = list(params)