summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/properties.py40
-rw-r--r--lib/sqlalchemy/sql/expression.py2
-rw-r--r--lib/sqlalchemy/sql/util.py48
-rw-r--r--lib/sqlalchemy/sql/visitors.py6
-rw-r--r--test/orm/test_joins.py12
-rw-r--r--test/orm/test_query.py8
-rw-r--r--test/orm/test_relationships.py2
-rw-r--r--test/sql/test_selectable.py19
8 files changed, 110 insertions, 27 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 59c4cb3dc..a590ad706 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -14,7 +14,7 @@ mapped attributes.
from sqlalchemy import sql, util, log, exc as sa_exc
from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
join_condition, _shallow_annotate
-from sqlalchemy.sql import operators, expression
+from sqlalchemy.sql import operators, expression, visitors
from sqlalchemy.orm import attributes, dependency, mapper, \
object_mapper, strategies, configure_mappers
from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
@@ -444,6 +444,7 @@ class RelationshipProperty(StrategizedProperty):
else:
j = _orm_annotate(pj, exclude=self.property.remote_side)
+ # MARKMARK
if criterion is not None and target_adapter:
# limit this adapter to annotated only?
criterion = target_adapter.traverse(criterion)
@@ -1376,6 +1377,34 @@ class RelationshipProperty(StrategizedProperty):
"argument to indicate which column lazy join "
"condition should bind." % (col, self.mapper))
+ count = [0]
+ def clone(elem):
+ if set(['local', 'remote']).intersection(elem._annotations):
+ return None
+ elif elem in self.local_side and elem in self.remote_side:
+ # TODO: OK this still sucks. this is basically,
+ # refuse, refuse, refuse the temptation to guess!
+ # but crap we really have to guess don't we. we
+ # might want to traverse here with cloned_traverse
+ # so we can see the binary exprs and do it at that
+ # level....
+ if count[0] % 2 == 0:
+ elem = elem._annotate({'local':True})
+ else:
+ elem = elem._annotate({'remote':True})
+ count[0] += 1
+ elif elem in self.local_side:
+ elem = elem._annotate({'local':True})
+ elif elem in self.remote_side:
+ elem = elem._annotate({'remote':True})
+ else:
+ elem = None
+ return elem
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, clone
+ )
+
def _generate_backref(self):
if not self.is_primary():
return
@@ -1539,17 +1568,20 @@ class RelationshipProperty(StrategizedProperty):
secondary_aliasizer.traverse(secondaryjoin)
else:
primary_aliasizer = ClauseAdapter(dest_selectable,
- exclude=self.local_side,
+ #exclude=self.local_side,
+ exclude_fn=lambda c: "local" in c._annotations,
equivalents=self.mapper._equivalent_columns)
if source_selectable is not None:
primary_aliasizer.chain(
ClauseAdapter(source_selectable,
- exclude=self.remote_side,
+ #exclude=self.remote_side,
+ exclude_fn=lambda c: "remote" in c._annotations,
equivalents=self.parent._equivalent_columns))
secondary_aliasizer = None
+
primaryjoin = primary_aliasizer.traverse(primaryjoin)
target_adapter = secondary_aliasizer or primary_aliasizer
- target_adapter.include = target_adapter.exclude = None
+ target_adapter.include = target_adapter.exclude = target_adapter.exclude_fn = None
else:
target_adapter = None
if source_selectable is None:
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b11e5ad42..30e19bc68 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -2184,7 +2184,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
for oth in to_compare:
if use_proxies and self.shares_lineage(oth):
return True
- elif oth is self:
+ elif hash(oth) == hash(self):
return True
else:
return False
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 97975441e..f0509c16f 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -225,7 +225,8 @@ def adapt_criterion_to_null(crit, nulls):
return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
-def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
+def join_condition(a, b, ignore_nonexistent_tables=False,
+ a_subset=None):
"""create a join condition between two tables or selectables.
e.g.::
@@ -535,6 +536,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
"'consider_as_foreign_keys' or "
"'consider_as_referenced_keys'")
+ def col_is(a, b):
+ #return a is b
+ return a.compare(b)
+
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
@@ -544,20 +549,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
if consider_as_foreign_keys:
if binary.left in consider_as_foreign_keys and \
- (binary.right is binary.left or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_foreign_keys):
pairs.append((binary.right, binary.left))
elif binary.right in consider_as_foreign_keys and \
- (binary.left is binary.right or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_foreign_keys):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
if binary.left in consider_as_referenced_keys and \
- (binary.right is binary.left or
+ (col_is(binary.right, binary.left) or
binary.right not in consider_as_referenced_keys):
pairs.append((binary.left, binary.right))
elif binary.right in consider_as_referenced_keys and \
- (binary.left is binary.right or
+ (col_is(binary.left, binary.right) or
binary.left not in consider_as_referenced_keys):
pairs.append((binary.right, binary.left))
else:
@@ -669,11 +674,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
s.c.col1 == table2.c.col1
"""
- def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False):
+ def __init__(self, selectable, equivalents=None,
+ include=None, exclude=None,
+ include_fn=None, exclude_fn=None,
+ adapt_on_names=False):
self.__traverse_options__ = {'stop_on':[selectable]}
self.selectable = selectable
- self.include = include
- self.exclude = exclude
+ if include:
+ assert not include_fn
+ self.include_fn = lambda e: e in include
+ else:
+ self.include_fn = include_fn
+ if exclude:
+ assert not exclude_fn
+ self.exclude_fn = lambda e: e in exclude
+ else:
+ self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
@@ -693,19 +709,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, expression.FromClause):
- if self.selectable.is_derived_from(col):
+ if isinstance(col, expression.FromClause) and \
+ self.selectable.is_derived_from(col):
return self.selectable
-
- if not isinstance(col, expression.ColumnElement):
+ elif not isinstance(col, expression.ColumnElement):
return None
-
- if self.include and col not in self.include:
+ elif self.include_fn and not self.include_fn(col):
return None
- elif self.exclude and col in self.exclude:
+ elif self.exclude_fn and self.exclude_fn(col):
return None
-
- return self._corresponding_column(col, True)
+ else:
+ return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
"""Extends ClauseAdapter with extra utility functions.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index cdcf40aa8..75e099f0d 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -240,16 +240,16 @@ def replacement_traverse(obj, opts, replace):
replacement by a given replacement function."""
cloned = util.column_dict()
- stop_on = util.column_set(opts.get('stop_on', []))
+ stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])])
def clone(elem, **kw):
- if elem in stop_on or \
+ if id(elem) in stop_on or \
'no_replacement_traverse' in elem._annotations:
return elem
else:
newelem = replace(elem)
if newelem is not None:
- stop_on.add(newelem)
+ stop_on.add(id(newelem))
return newelem
else:
if elem not in cloned:
diff --git a/test/orm/test_joins.py b/test/orm/test_joins.py
index db7c78cdd..6c43a2f39 100644
--- a/test/orm/test_joins.py
+++ b/test/orm/test_joins.py
@@ -1700,21 +1700,29 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
sess.flush()
sess.close()
- def test_join(self):
+ def test_join_1(self):
Node = self.classes.Node
-
sess = create_session()
node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
assert node.data=='n12'
+ def test_join_2(self):
+ Node = self.classes.Node
+ sess = create_session()
ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all()
assert ret == [('n12',)]
+ def test_join_3(self):
+ Node = self.classes.Node
+ sess = create_session()
node = sess.query(Node).join('children', 'children', aliased=True).filter_by(data='n122').first()
assert node.data=='n1'
+ def test_join_4(self):
+ Node = self.classes.Node
+ sess = create_session()
node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\
join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
assert node.data == 'n122'
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 24974ae7e..155f7c68d 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -622,6 +622,14 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
self._test(Address.user == None, "addresses.user_id IS NULL")
self._test(Address.user != None, "addresses.user_id IS NOT NULL")
+
+ def test_foo(self):
+ Node = self.classes.Node
+ nalias = aliased(Node)
+ self._test(
+ nalias.parent.has(Node.data=='some data'),
+ "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)"
+ )
def test_selfref_relationship(self):
Node = self.classes.Node
diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py
index 6781d7104..2049088af 100644
--- a/test/orm/test_relationships.py
+++ b/test/orm/test_relationships.py
@@ -249,6 +249,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
def _test(self):
Employee, Company = self.classes.Employee, self.classes.Company
+# employee_t = self.tables.employee_t
+# assert Employee.reports_to.property.local_remote_pairs == [(employee_t.c.reports_to_id, employee_t.c.emp_id), (employee_t.c.company_id, employee_t.c.company_id)]
sess = create_session()
c1 = Company()
diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py
index 8f599f1d6..6d85f7c4f 100644
--- a/test/sql/test_selectable.py
+++ b/test/sql/test_selectable.py
@@ -1023,6 +1023,25 @@ class AnnotationsTest(fixtures.TestBase):
annot = obj._annotate({})
eq_(set([obj]), set([annot]))
+ def test_compare(self):
+ t = table('t', column('x'), column('y'))
+ x_a = t.c.x._annotate({})
+ assert t.c.x.compare(x_a)
+ assert x_a.compare(t.c.x)
+ assert not x_a.compare(t.c.y)
+ assert not t.c.y.compare(x_a)
+ assert (t.c.x == 5).compare(x_a == 5)
+ assert not (t.c.y == 5).compare(x_a == 5)
+
+ s = select([t])
+ x_p = s.c.x
+ assert not x_a.compare(x_p)
+ assert not t.c.x.compare(x_p)
+ x_p_a = x_p._annotate({})
+ assert x_p_a.compare(x_p)
+ assert x_p.compare(x_p_a)
+ assert not x_p_a.compare(x_a)
+
def test_custom_constructions(self):
from sqlalchemy.schema import Column
class MyColumn(Column):