diff options
| -rw-r--r-- | lib/sqlalchemy/orm/evaluator.py | 26 | ||||
| -rw-r--r-- | test/orm/test_evaluator.py | 15 |
2 files changed, 35 insertions, 6 deletions
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index f7f12ce12..23c48329d 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -17,6 +17,16 @@ class UnevaluatableError(Exception): pass +class _NoObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return None + + def reverse_operate(self, *arg, **kw): + return None + + +_NO_OBJECT = _NoObject() + _straight_ops = set( getattr(operators, op) for op in ( @@ -36,8 +46,10 @@ _straight_ops = set( ) _extended_ops = { - operators.in_op: (lambda a, b: a in b), - operators.not_in_op: (lambda a, b: a not in b), + operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None), + operators.not_in_op: ( + lambda a, b: a not in b if a is not _NO_OBJECT else None + ), } _notimplemented_ops = set( @@ -111,7 +123,11 @@ class EvaluatorCompiler(object): raise UnevaluatableError("Cannot evaluate column: %s" % clause) get_corresponding_attr = operator.attrgetter(key) - return lambda obj: get_corresponding_attr(obj) + return ( + lambda obj: get_corresponding_attr(obj) + if obj is not None + else _NO_OBJECT + ) def visit_tuple(self, clause): return self.visit_clauselist(clause) @@ -137,7 +153,7 @@ class EvaluatorCompiler(object): for sub_evaluate in evaluators: value = sub_evaluate(obj) if not value: - if value is None: + if value is None or value is _NO_OBJECT: return None return False return True @@ -148,7 +164,7 @@ class EvaluatorCompiler(object): values = [] for sub_evaluate in evaluators: value = sub_evaluate(obj) - if value is None: + if value is None or value is _NO_OBJECT: return None values.append(value) return tuple(values) diff --git a/test/orm/test_evaluator.py b/test/orm/test_evaluator.py index a6c889aa7..955e5134f 100644 --- a/test/orm/test_evaluator.py +++ b/test/orm/test_evaluator.py @@ -100,7 +100,11 @@ class EvaluateTest(fixtures.MappedTest): eval_eq( User.name == None, # noqa - testcases=[(User(name="foo"), False), (User(name=None), True)], + testcases=[ + (User(name="foo"), False), + (User(name=None), True), + (None, None), + ], ) def test_warn_on_unannotated_matched_column(self): @@ -144,6 +148,7 @@ class EvaluateTest(fixtures.MappedTest): (User(name="foo"), False), (User(name=True), False), (User(name=False), True), + (None, None), ], ) @@ -153,6 +158,7 @@ class EvaluateTest(fixtures.MappedTest): (User(name="foo"), False), (User(name=True), True), (User(name=False), False), + (None, None), ], ) @@ -167,6 +173,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=1, name="bar"), False), (User(id=2, name="bar"), False), (User(id=1, name=None), None), + (None, None), ], ) @@ -179,6 +186,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=2, name="bar"), False), (User(id=1, name=None), True), (User(id=2, name=None), None), + (None, None), ], ) @@ -201,6 +209,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=2, name="bat"), False), (User(id=1, name="bar"), True), (User(id=1, name=None), None), + (None, None), ], ) @@ -211,6 +220,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=2, name="bat"), True), (User(id=1, name="bar"), False), (User(id=1, name=None), None), + (None, None), ], ) @@ -225,6 +235,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=1, name="bar"), False), (User(id=2, name="bar"), True), (User(id=1, name=None), None), + (None, None), ], ) @@ -236,6 +247,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=1, name="bar"), True), (User(id=2, name="bar"), False), (User(id=1, name=None), None), + (None, None), ], ) @@ -251,6 +263,7 @@ class EvaluateTest(fixtures.MappedTest): (User(id=2, name="bar"), True), (User(id=None, name="foo"), None), (User(id=None, name=None), None), + (None, None), ], ) |
