summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Bradshaw <robertwb@math.washington.edu>2009-11-11 12:45:58 -0800
committerRobert Bradshaw <robertwb@math.washington.edu>2009-11-11 12:45:58 -0800
commit350f769244d64627f9202e254c75aef9b35e8c35 (patch)
tree3dbe97eac22592a65808a5f661c9eb0f8f04aa1e
parent1175bb684d31f5b2e8c1a02cdd2f584c1cc02ace (diff)
downloadcython-350f769244d64627f9202e254c75aef9b35e8c35.tar.gz
Fix for in/not in cascading.
-rw-r--r--Cython/Compiler/ExprNodes.py49
-rw-r--r--tests/run/contains_T455.pyx15
2 files changed, 47 insertions, 17 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 353a78e09..1b6c2e828 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -5231,9 +5231,11 @@ class CmpNode(object):
else:
negation = ""
if op == 'in' or op == 'not_in':
- assert not coerce_result
+ code.globalstate.use_utility_code(contians_utility_code)
+ if self.type is PyrexTypes.py_object_type:
+ coerce_result = "__Pyx_PyBoolOrNull_FromLong"
if op == 'not_in':
- negation = "if (likely(%s != -1)) %s = !%s; " % ((result_code,)*3)
+ negation = "__Pyx_NegateNonNeg"
if operand2.type is dict_type:
code.globalstate.use_utility_code(
raise_none_iter_error_utility_code)
@@ -5241,22 +5243,27 @@ class CmpNode(object):
code.putln("__Pyx_RaiseNoneNotIterableError(); %s" %
code.error_goto(self.pos))
code.putln("} else {")
- code.putln(
- "%s = PyDict_Contains(%s, %s); %s%s" % (
- result_code,
- operand2.py_result(),
- operand1.py_result(),
- negation,
- code.error_goto_if_neg(result_code, self.pos)))
- code.putln("}")
+ method = "PyDict_Contains"
else:
- code.putln(
- "%s = PySequence_Contains(%s, %s); %s%s" % (
- result_code,
- operand2.py_result(),
- operand1.py_result(),
- negation,
- code.error_goto_if_neg(result_code, self.pos)))
+ method = "PySequence_Contains"
+ if self.type is PyrexTypes.py_object_type:
+ error_clause = code.error_goto_if_null
+ got_ref = "__Pyx_XGOTREF(%s); " % result_code
+ else:
+ error_clause = code.error_goto_if_neg
+ got_ref = ""
+ code.putln(
+ "%s = %s(%s(%s(%s, %s))); %s%s" % (
+ result_code,
+ coerce_result,
+ negation,
+ method,
+ operand2.py_result(),
+ operand1.py_result(),
+ got_ref,
+ error_clause(result_code, self.pos)))
+ if operand2.type is dict_type:
+ code.putln("}")
elif (operand1.type.is_pyobject
and op not in ('is', 'is_not')):
@@ -5306,6 +5313,14 @@ class CmpNode(object):
else:
return op
+contians_utility_code = UtilityCode(
+proto="""
+static INLINE long __Pyx_NegateNonNeg(long b) { return unlikely(b < 0) ? b : !b; }
+static INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
+ return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
+}
+""")
+
class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of
diff --git a/tests/run/contains_T455.pyx b/tests/run/contains_T455.pyx
index 669f89c16..1fa018212 100644
--- a/tests/run/contains_T455.pyx
+++ b/tests/run/contains_T455.pyx
@@ -80,3 +80,18 @@ def not_in_dict(k, dict dct):
TypeError: 'NoneType' object is not iterable
"""
return k not in dct
+
+def cascaded(a, b, c):
+ """
+ >>> cascaded(1, 2, 3)
+ Traceback (most recent call last):
+ ...
+ TypeError: argument of type 'int' is not iterable
+ >>> cascaded(-1, (1,2), (1,3))
+ True
+ >>> cascaded(1, (1,2), (1,3))
+ False
+ >>> cascaded(-1, (1,2), (1,0))
+ False
+ """
+ return a not in b < c