summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-02-17 01:15:43 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-02-17 01:15:43 +0000
commita3f67fecb27363c73f833cc72cefbff5e8754598 (patch)
tree349edaf8f429ebafab3e3db7445f2ab583c4d74d /lib
parent191dbee5c899af3a80050dcfd844c5ebc04195b2 (diff)
downloadsqlalchemy-a3f67fecb27363c73f833cc72cefbff5e8754598.tar.gz
- any(), has(), contains(), attribute level == and != now
work properly with self-referential relations - the clause inside the EXISTS is aliased on the "remote" side to distinguish it from the parent table.
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/properties.py51
1 files changed, 31 insertions, 20 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index d08dd7124..6339ec575 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -15,7 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
from sqlalchemy.sql import visitors, operators, ColumnElement
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm.util import CascadeOptions
+from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
import weakref
@@ -265,33 +265,44 @@ class PropertyLoader(StrategizedProperty):
return sql.and_(*clauses)
else:
return self.prop._optimized_compare(other)
+
+ def _join_and_criterion(self, criterion=None, **kwargs):
+ if self.prop._is_self_referential():
+ pac = PropertyAliasedClauses(self.prop,
+ self.prop.primaryjoin,
+ self.prop.secondaryjoin)
+ j = pac.primaryjoin
+ if pac.secondaryjoin:
+ j = j & pac.secondaryjoin
+ else:
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
- def any(self, criterion=None, **kwargs):
- if not self.prop.uselist:
- raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
for k in kwargs:
crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
if criterion is None:
criterion = crit
else:
criterion = criterion & crit
+
+ if criterion and self.prop._is_self_referential():
+ criterion = pac.adapt_clause(criterion)
+
+ return j, criterion
+
+ def any(self, criterion=None, **kwargs):
+ if not self.prop.uselist:
+ raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
+ j, criterion = self._join_and_criterion(criterion, **kwargs)
+
return sql.exists([1], j & criterion)
def has(self, criterion=None, **kwargs):
if self.prop.uselist:
raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
- for k in kwargs:
- crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
- if criterion is None:
- criterion = crit
- else:
- criterion = criterion & crit
+ j, criterion = self._join_and_criterion(criterion, **kwargs)
+
return sql.exists([1], j & criterion)
def contains(self, other):
@@ -309,11 +320,11 @@ class PropertyLoader(StrategizedProperty):
def __ne__(self, other):
if self.prop.uselist and not hasattr(other, '__iter__'):
raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+
+ criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
+ j, criterion = self._join_and_criterion(criterion)
- j = self.prop.primaryjoin
- if self.prop.secondaryjoin:
- j = j & self.prop.secondaryjoin
- return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+ return ~sql.exists([1], j & criterion)
def compare(self, op, value, value_is_parent=False):
if op == operators.eq: