summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/orm/query.py2
-rw-r--r--lib/sqlalchemy/sql/expression.py58
-rw-r--r--lib/sqlalchemy/sql/util.py17
-rw-r--r--lib/sqlalchemy/sql/visitors.py25
-rw-r--r--test/orm/eager_relations.py17
-rw-r--r--test/orm/query.py4
-rw-r--r--test/sql/generative.py60
8 files changed, 127 insertions, 59 deletions
diff --git a/CHANGES b/CHANGES
index ee5831d58..f76a85901 100644
--- a/CHANGES
+++ b/CHANGES
@@ -172,6 +172,9 @@ CHANGES
- fixed endless loop issue when using lazy="dynamic" on both
sides of a bi-directional relationship [ticket:872]
+ - more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads,
+ in this case when mapped against a select statement [ticket:904]
+
- fix to self-referential eager loading such that if the same mapped
instance appears in two or more distinct sets of columns in the same
result set, its eagerly loaded collection will be populated regardless
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d75a9b8b2..902a4fd3b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -919,7 +919,7 @@ class Query(object):
adapt_criterion = self.table not in self._get_joinable_tables()
if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
- whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj]))
+ whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause)
# TODO: mappers added via add_entity(), adapt their queries also,
# if those mappers are polymorphic
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index a4d8fa6a0..ddeaaf8ad 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -28,6 +28,7 @@ to stay the same in future releases.
import re
import datetime
import warnings
+from itertools import chain
from sqlalchemy import util, exceptions
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
@@ -864,6 +865,13 @@ class ClauseElement(object):
return c
+ def _cloned_set(self):
+ f = self
+ while f is not None:
+ yield f
+ f = getattr(f, '_is_clone_of', None)
+ _cloned_set = property(_cloned_set)
+
def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
@@ -1543,7 +1551,8 @@ class FromClause(Selectable):
__visit_name__ = 'fromclause'
named_with_column=False
-
+ _hide_froms = []
+
def __init__(self):
self.oid_column = None
@@ -1588,7 +1597,7 @@ class FromClause(Selectable):
An example would be an Alias of a Table is derived from that Table.
"""
- return fromclause is self
+ return fromclause in util.Set(self._cloned_set)
def replace_selectable(self, old, alias):
"""replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
@@ -1649,22 +1658,6 @@ class FromClause(Selectable):
return getattr(self, 'name', self.__class__.__name__ + " object")
description = property(description)
- def _aggregate_hide_froms(self, **modifiers):
- """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces, taking into account
- the element which this element was cloned from (and so on until the orginal is reached).
- """
-
- s = self
- while s is not None:
- for h in s._hide_froms(**modifiers):
- yield h
- s = getattr(s, '_is_clone_of', None)
-
- def _hide_froms(self, **modifiers):
- """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces."""
-
- return []
-
def _clone_from_clause(self):
# delete all the "generated" collections of columns for a
# newly cloned FromClause, so that they will be re-derived
@@ -2230,6 +2223,7 @@ class Join(FromClause):
def __init__(self, left, right, onclause=None, isouter = False):
self.left = _selectable(left)
self.right = _selectable(right).self_group()
+
self.oid_column = self.left.oid_column
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
@@ -2303,7 +2297,7 @@ class Join(FromClause):
self.right = clone(self.right)
self.onclause = clone(self.onclause)
self.__folded_equivalents = None
-
+
def get_children(self, **kwargs):
return self.left, self.right, self.onclause
@@ -2409,9 +2403,10 @@ class Join(FromClause):
return self.select(use_labels=True, correlate=False).alias(name)
- def _hide_froms(self, **modifiers):
- return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
-
+ def _hide_froms(self):
+ return chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+ _hide_froms = property(_hide_froms)
+
def _get_from_objects(self, **modifiers):
return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
@@ -2450,6 +2445,8 @@ class Alias(FromClause):
description = property(description)
def is_derived_from(self, fromclause):
+ if fromclause in util.Set(self._cloned_set):
+ return True
return self.selectable.is_derived_from(fromclause)
def supports_execution(self):
@@ -2527,13 +2524,11 @@ class _FromGrouping(FromClause):
self.elem = elem
columns = c = property(lambda s:s.elem.columns)
-
+ _hide_froms = property(lambda s:s.elem._hide_froms)
+
def get_children(self, **kwargs):
return self.elem,
- def _hide_froms(self, **modifiers):
- return self.elem._hide_froms(**modifiers)
-
def _copy_internals(self, clone=_clone):
self.elem = clone(self.elem)
@@ -3066,7 +3061,6 @@ class Select(_SelectBaseMixin, FromClause):
"""
froms = util.OrderedSet()
- hide_froms = util.Set()
for col in self._raw_columns:
froms.update(col._get_from_objects())
@@ -3078,14 +3072,13 @@ class Select(_SelectBaseMixin, FromClause):
froms.update(self._froms)
for f in froms:
- hide_froms.update(f._aggregate_hide_froms())
- froms = froms.difference(hide_froms)
+ froms.difference_update(f._hide_froms)
if len(froms) > 1:
if self.__correlate:
- froms = froms.difference(self.__correlate)
+ froms.difference_update(self.__correlate)
if self._should_correlate and existing_froms is not None:
- froms = froms.difference(existing_froms)
+ froms.difference_update(existing_froms)
if not froms:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
@@ -3129,6 +3122,9 @@ class Select(_SelectBaseMixin, FromClause):
inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
def is_derived_from(self, fromclause):
+ if self in util.Set(fromclause._cloned_set):
+ return True
+
for f in self.locate_all_froms():
if f.is_derived_from(fromclause):
return True
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 5aa985f47..d6b10a78a 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -71,6 +71,9 @@ class AbstractClauseProcessor(object):
__traverse_options__ = {'column_collections':False}
+ def __init__(self, stop_on=None):
+ self.stop_on = stop_on
+
def convert_element(self, elem):
"""Define the *conversion* method for this ``AbstractClauseProcessor``."""
@@ -92,13 +95,14 @@ class AbstractClauseProcessor(object):
setattr(tail, attr, visitor)
return self
- def copy_and_process(self, list_, stop_on=None):
+ def copy_and_process(self, list_):
"""Copy the given list to a new list, with each element traversed individually."""
list_ = list(list_)
- stop_on = util.Set()
+ stop_on = util.Set(self.stop_on or [])
+ cloned = {}
for i in range(0, len(list_)):
- list_[i] = self.traverse(list_[i], stop_on=stop_on)
+ list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True)
return list_
def _convert_element(self, elem, stop_on, cloned):
@@ -116,13 +120,11 @@ class AbstractClauseProcessor(object):
cloned[elem] = elem._clone()
return cloned[elem]
- def traverse(self, elem, clone=True, stop_on=None):
+ def traverse(self, elem, clone=True):
if not clone:
raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
- if stop_on is None:
- stop_on = util.Set()
- return self._traverse(elem, stop_on, {}, _clone_toplevel=True)
+ return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True)
def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
if elem in stop_on:
@@ -178,6 +180,7 @@ class ClauseAdapter(AbstractClauseProcessor):
"""
def __init__(self, selectable, include=None, exclude=None, equivalents=None):
+ AbstractClauseProcessor.__init__(self, [selectable])
self.selectable = selectable
self.include = include
self.exclude = exclude
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 150ee9cc7..bb63ab09c 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -37,18 +37,17 @@ class ClauseVisitor(object):
meth(obj, **kwargs)
v = getattr(v, '_next', None)
- def iterate(self, obj, stop_on=None):
+ def iterate(self, obj):
stack = [obj]
traversal = []
while len(stack) > 0:
t = stack.pop()
- if stop_on is None or t not in stop_on:
- yield t
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
+ yield t
+ traversal.insert(0, t)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
- def traverse(self, obj, stop_on=None, clone=False):
+ def traverse(self, obj, clone=False):
if clone:
cloned = {}
@@ -60,17 +59,15 @@ class ClauseVisitor(object):
return cloned[obj]
obj = do_clone(obj)
-
stack = [obj]
traversal = []
while len(stack) > 0:
t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- if clone:
- t._copy_internals(clone=do_clone)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
+ traversal.insert(0, t)
+ if clone:
+ t._copy_internals(clone=do_clone)
+ for c in t.get_children(**self.__traverse_options__):
+ stack.append(c)
for target in traversal:
v = self
while v is not None:
diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py
index 7a822234c..bef4ecffc 100644
--- a/test/orm/eager_relations.py
+++ b/test/orm/eager_relations.py
@@ -418,6 +418,23 @@ class EagerTest(FixtureTest):
)
] == l.all()
+ def test_limit_4(self):
+ # tests the LIMIT/OFFSET aliasing on a mapper against a select. original issue from ticket #904
+ sel = select([users, addresses.c.email_address], users.c.id==addresses.c.user_id).alias('useralias')
+ mapper(User, sel, properties={
+ 'orders':relation(Order, primaryjoin=sel.c.id==orders.c.user_id, lazy=False)
+ })
+ mapper(Order, orders)
+
+ sess = create_session()
+ self.assertEquals(sess.query(User).first(),
+ User(name=u'jack',orders=[
+ Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1),
+ Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3),
+ Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5)],
+ email_address=u'jack@bean.com',id=7)
+ )
+
def test_one_to_many_scalar(self):
mapper(User, users, properties = dict(
address = relation(mapper(Address, addresses), lazy=False, uselist=False)
diff --git a/test/orm/query.py b/test/orm/query.py
index 1dc50ebbb..9471f1012 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -955,6 +955,10 @@ class SelectFromTest(QueryTest):
self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
User(name='ed',id=8), User(name='jack',id=7)
])
+
+ self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(),
+ User(name='jack', addresses=[Address(id=1)])
+ )
def test_join(self):
mapper(User, users, properties = {
diff --git a/test/sql/generative.py b/test/sql/generative.py
index 847443330..41b4caebf 100644
--- a/test/sql/generative.py
+++ b/test/sql/generative.py
@@ -269,7 +269,22 @@ class ClauseTest(SQLCompileTest):
self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2_1")
- def test_clause_adapter(self):
+class ClauseAdapterTest(SQLCompileTest):
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+
+ def test_table_to_alias(self):
t1alias = t1.alias('t1alias')
@@ -302,7 +317,7 @@ class ClauseTest(SQLCompileTest):
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
- def test_selfreferential(self):
+ def test_include_exclude(self):
m = MetaData()
a=Table( 'a',m,
Column( 'id', Integer, primary_key=True),
@@ -319,10 +334,7 @@ class ClauseTest(SQLCompileTest):
assert str(e) == "a_1.id = a.xxx_id"
- def test_joins(self):
- """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
- replacing."""
-
+ def test_join_to_alias(self):
metadata = MetaData()
a = Table('a', metadata,
Column('id', Integer, primary_key=True))
@@ -359,6 +371,42 @@ class ClauseTest(SQLCompileTest):
"c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
"ON b_id = c.bid) AS foo"
" LEFT OUTER JOIN d ON foo.a_id = d.aid")
+
+ def test_derived_from(self):
+ assert select([t1]).is_derived_from(t1)
+ assert not select([t2]).is_derived_from(t1)
+ assert not t1.is_derived_from(select([t1]))
+ assert t1.alias().is_derived_from(t1)
+
+
+ s1 = select([t1, t2]).alias('foo')
+ s2 = select([s1]).limit(5).offset(10).alias()
+ assert s2.is_derived_from(s1)
+ s2 = s2._clone()
+ assert s2.is_derived_from(s1)
+
+ def test_aliasedselect_to_aliasedselect(self):
+ # original issue from ticket #904
+ s1 = select([t1]).alias('foo')
+ s2 = select([s1]).limit(5).offset(10).alias()
+
+ self.assert_compile(sql_util.ClauseAdapter(s2).traverse(s1),
+ "SELECT foo.col1, foo.col2, foo.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10")
+
+ j = s1.outerjoin(t2, s1.c.col1==t2.c.col1)
+ self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(),
+ "SELECT anon_1.col1, anon_1.col2, anon_1.col3, table2.col1, table2.col2, table2.col3 FROM "\
+ "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10) AS anon_1 "\
+ "LEFT OUTER JOIN table2 ON anon_1.col1 = table2.col1")
+
+ talias = t1.alias('bar')
+ j = s1.outerjoin(talias, s1.c.col1==talias.c.col1)
+ self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(),
+ "SELECT anon_1.col1, anon_1.col2, anon_1.col3, bar.col1, bar.col2, bar.col3 FROM "\
+ "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\
+ "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo LIMIT 5 OFFSET 10) AS anon_1 "\
+ "LEFT OUTER JOIN table1 AS bar ON anon_1.col1 = bar.col1")
class SelectTest(SQLCompileTest):