summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-12 17:56:52 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-12 17:56:52 +0000
commit5c14b20f9f02179e4e59e3f196cbab5da8366583 (patch)
tree4305d5eab7b8c6c3eeeef77a1b9447a56e7b0173
parent16810e401139644d9d137d9a18b4d945318db35c (diff)
downloadsqlalchemy-5c14b20f9f02179e4e59e3f196cbab5da8366583.tar.gz
implemented many-to-one comparisons to None generate <column> IS NULL, with column on the left side in all cases
-rw-r--r--CHANGES4
-rw-r--r--lib/sqlalchemy/orm/properties.py5
-rw-r--r--lib/sqlalchemy/orm/strategies.py25
-rw-r--r--test/orm/query.py5
-rw-r--r--test/testlib/fixtures.py2
5 files changed, 38 insertions, 3 deletions
diff --git a/CHANGES b/CHANGES
index 38dec541b..dcc6cbdb2 100644
--- a/CHANGES
+++ b/CHANGES
@@ -79,6 +79,10 @@ CHANGES
statements as well. Filter criterion, order bys, eager load
clauses will be "aliased" against the given statement.
+ - query.filter(SomeClass.somechild == None), when comparing
+ a many-to-one property to None, properly generates "id IS NULL"
+ including that the NULL is on the right side.
+
- eagerload(), lazyload(), eagerload_all() take an optional
second class-or-mapper argument, which will select the mapper
to apply the option towards. This can select among other
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 4d41556a0..b6d6cef63 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -232,7 +232,10 @@ class PropertyLoader(StrategizedProperty):
class Comparator(PropComparator):
def __eq__(self, other):
if other is None:
- return ~sql.exists([1], self.prop.primaryjoin)
+ if self.prop.uselist:
+ return ~sql.exists([1], self.prop.primaryjoin)
+ else:
+ return self.prop._optimized_compare(None)
elif self.prop.uselist:
if not hasattr(other, '__iter__'):
raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().")
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index d9390345e..2ba9d6be1 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -8,7 +8,7 @@
from sqlalchemy import sql, util, exceptions, logging
from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql import visitors
+from sqlalchemy.sql import visitors, expression, operators
from sqlalchemy.orm import mapper, attributes
from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
from sqlalchemy.orm import session as sessionlib
@@ -292,6 +292,9 @@ class LazyLoader(AbstractRelationLoader):
self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
def lazy_clause(self, instance, reverse_direction=False):
+ if instance is None:
+ return self.lazy_none_clause(reverse_direction)
+
if not reverse_direction:
(criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
else:
@@ -305,6 +308,26 @@ class LazyLoader(AbstractRelationLoader):
bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
+ def lazy_none_clause(self, reverse_direction=False):
+ if not reverse_direction:
+ (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+ else:
+ (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ def visit_binary(binary):
+ mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
+ if isinstance(binary.left, expression._BindParamClause) and binary.left.key in bind_to_col:
+ # reverse order if the NULL is on the left side
+ binary.left = binary.right
+ binary.right = expression.null()
+ binary.operator = operators.is_
+ elif isinstance(binary.right, expression._BindParamClause) and binary.right.key in bind_to_col:
+ binary.right = expression.null()
+ binary.operator = operators.is_
+
+ return visitors.traverse(criterion, clone=True, visit_binary=visit_binary)
+
def setup_loader(self, instance, options=None, path=None):
if not mapper.has_mapper(instance):
return None
diff --git a/test/orm/query.py b/test/orm/query.py
index 5f85151f0..efd890e0e 100644
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -376,6 +376,11 @@ class FilterTest(QueryTest):
assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all()
+ # generates an IS NULL
+ assert [] == sess.query(Address).filter(Address.user == None).all()
+
+ assert [Order(id=5)] == sess.query(Order).filter(Order.address == None).all()
+
class AggregateTest(QueryTest):
def test_sum(self):
sess = create_session()
diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py
index 4394780bb..2a4b457ac 100644
--- a/test/testlib/fixtures.py
+++ b/test/testlib/fixtures.py
@@ -150,7 +150,7 @@ def install_fixture_data():
dict(id = 2, user_id = 9, description = 'order 2', isopen=0, address_id=4),
dict(id = 3, user_id = 7, description = 'order 3', isopen=1, address_id=1),
dict(id = 4, user_id = 9, description = 'order 4', isopen=1, address_id=4),
- dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=1)
+ dict(id = 5, user_id = 7, description = 'order 5', isopen=0, address_id=None)
)
items.insert().execute(
dict(id=1, description='item 1'),