summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/evaluator.py26
-rw-r--r--test/orm/test_evaluator.py15
2 files changed, 35 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
index f7f12ce12..23c48329d 100644
--- a/lib/sqlalchemy/orm/evaluator.py
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -17,6 +17,16 @@ class UnevaluatableError(Exception):
pass
+class _NoObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return None
+
+ def reverse_operate(self, *arg, **kw):
+ return None
+
+
+_NO_OBJECT = _NoObject()
+
_straight_ops = set(
getattr(operators, op)
for op in (
@@ -36,8 +46,10 @@ _straight_ops = set(
)
_extended_ops = {
- operators.in_op: (lambda a, b: a in b),
- operators.not_in_op: (lambda a, b: a not in b),
+ operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
+ operators.not_in_op: (
+ lambda a, b: a not in b if a is not _NO_OBJECT else None
+ ),
}
_notimplemented_ops = set(
@@ -111,7 +123,11 @@ class EvaluatorCompiler(object):
raise UnevaluatableError("Cannot evaluate column: %s" % clause)
get_corresponding_attr = operator.attrgetter(key)
- return lambda obj: get_corresponding_attr(obj)
+ return (
+ lambda obj: get_corresponding_attr(obj)
+ if obj is not None
+ else _NO_OBJECT
+ )
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
@@ -137,7 +153,7 @@ class EvaluatorCompiler(object):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if not value:
- if value is None:
+ if value is None or value is _NO_OBJECT:
return None
return False
return True
@@ -148,7 +164,7 @@ class EvaluatorCompiler(object):
values = []
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value is None:
+ if value is None or value is _NO_OBJECT:
return None
values.append(value)
return tuple(values)
diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py
index a6c889aa7..955e5134f 100644
--- a/test/orm/test_evaluator.py
+++ b/test/orm/test_evaluator.py
@@ -100,7 +100,11 @@ class EvaluateTest(fixtures.MappedTest):
eval_eq(
User.name == None, # noqa
- testcases=[(User(name="foo"), False), (User(name=None), True)],
+ testcases=[
+ (User(name="foo"), False),
+ (User(name=None), True),
+ (None, None),
+ ],
)
def test_warn_on_unannotated_matched_column(self):
@@ -144,6 +148,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(name="foo"), False),
(User(name=True), False),
(User(name=False), True),
+ (None, None),
],
)
@@ -153,6 +158,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(name="foo"), False),
(User(name=True), True),
(User(name=False), False),
+ (None, None),
],
)
@@ -167,6 +173,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=1, name="bar"), False),
(User(id=2, name="bar"), False),
(User(id=1, name=None), None),
+ (None, None),
],
)
@@ -179,6 +186,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=2, name="bar"), False),
(User(id=1, name=None), True),
(User(id=2, name=None), None),
+ (None, None),
],
)
@@ -201,6 +209,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=2, name="bat"), False),
(User(id=1, name="bar"), True),
(User(id=1, name=None), None),
+ (None, None),
],
)
@@ -211,6 +220,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=2, name="bat"), True),
(User(id=1, name="bar"), False),
(User(id=1, name=None), None),
+ (None, None),
],
)
@@ -225,6 +235,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=1, name="bar"), False),
(User(id=2, name="bar"), True),
(User(id=1, name=None), None),
+ (None, None),
],
)
@@ -236,6 +247,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=1, name="bar"), True),
(User(id=2, name="bar"), False),
(User(id=1, name=None), None),
+ (None, None),
],
)
@@ -251,6 +263,7 @@ class EvaluateTest(fixtures.MappedTest):
(User(id=2, name="bar"), True),
(User(id=None, name="foo"), None),
(User(id=None, name=None), None),
+ (None, None),
],
)