diff options
author | scoder <stefan_ml@behnel.de> | 2023-04-21 08:25:20 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-21 08:25:20 +0200 |
commit | bcb6c0e1af0086b7900db518e73e61e48ee5167b (patch) | |
tree | 2c8a9a3cdaccc0f1b77804afe3681fbf871a66d5 | |
parent | a75afc03415f1b887434991ac0db2de0e6555ed7 (diff) | |
download | cython-bcb6c0e1af0086b7900db518e73e61e48ee5167b.tar.gz |
Avoid Python int object creation when multiplying sequences with C integers (GH-5213)
* Avoid redundant subtree analysis in MulNode when multiplying sequences with unknown types.
* Avoid Python int creation when multiplying sequences with integers.
* Also allow a C integer ("cint") mult_factor for sequences, avoiding coercion to a Python int object if possible.
* Also optimise (int * ctuple): the ctuple is coerced to a Python tuple before multiplying, so the result ends up as a Python tuple as well.
* Make sure we only apply a "mult_factor" to a Python sequence (not ctuples), and make the re-analysis of TupleNode a little safer.
-rw-r--r-- | Cython/Compiler/Builtin.py | 36 | ||||
-rw-r--r-- | Cython/Compiler/ExprNodes.py | 122 | ||||
-rw-r--r-- | Cython/Utility/ObjectHandling.c | 30 | ||||
-rw-r--r-- | tests/run/seq_mul.py | 21 | ||||
-rw-r--r-- | tests/run/unicodemethods.pyx | 155 |
5 files changed, 328 insertions, 36 deletions
diff --git a/Cython/Compiler/Builtin.py b/Cython/Compiler/Builtin.py index c84ee16fd..6b00e94f9 100644 --- a/Cython/Compiler/Builtin.py +++ b/Cython/Compiler/Builtin.py @@ -292,21 +292,33 @@ builtin_types_table = [ ("basestring", "PyBaseString_Type", [ BuiltinMethod("join", "TO", "T", "__Pyx_PyBaseString_Join", utility_code=UtilityCode.load("StringJoin", "StringTools.c")), + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), ("bytearray", "PyByteArray_Type", [ + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), ("bytes", "PyBytes_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyBytes_Join", utility_code=UtilityCode.load("StringJoin", "StringTools.c")), + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), ("str", "PyString_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join", builtin_return_type='basestring', utility_code=UtilityCode.load("StringJoin", "StringTools.c")), + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), ("unicode", "PyUnicode_Type", [BuiltinMethod("__contains__", "TO", "b", "PyUnicode_Contains"), BuiltinMethod("join", "TO", "T", "PyUnicode_Join"), + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), - ("tuple", "PyTuple_Type", []), + ("tuple", "PyTuple_Type", [BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), + ]), ("list", "PyList_Type", [BuiltinMethod("insert", "TzO", "r", "PyList_Insert"), BuiltinMethod("reverse", "T", "r", "PyList_Reverse"), @@ -314,6 +326,8 @@ builtin_types_table 
= [ utility_code=UtilityCode.load("ListAppend", "Optimize.c")), BuiltinMethod("extend", "TO", "r", "__Pyx_PyList_Extend", utility_code=UtilityCode.load("ListExtend", "Optimize.c")), + BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply", + utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")), ]), ("dict", "PyDict_Type", [BuiltinMethod("__contains__", "TO", "b", "PyDict_Contains"), @@ -479,10 +493,11 @@ def init_builtins(): '__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type), pos=None, cname='__pyx_assertions_enabled()', is_cdef=True) - global type_type, list_type, tuple_type, dict_type, set_type, frozenset_type - global slice_type, bytes_type, str_type, unicode_type, basestring_type, bytearray_type + global type_type, list_type, tuple_type, dict_type, set_type, frozenset_type, slice_type + global bytes_type, str_type, unicode_type, basestring_type, bytearray_type global float_type, int_type, long_type, bool_type, complex_type global memoryview_type, py_buffer_type + global sequence_types type_type = builtin_scope.lookup('type').type list_type = builtin_scope.lookup('list').type tuple_type = builtin_scope.lookup('tuple').type @@ -490,17 +505,30 @@ def init_builtins(): set_type = builtin_scope.lookup('set').type frozenset_type = builtin_scope.lookup('frozenset').type slice_type = builtin_scope.lookup('slice').type + bytes_type = builtin_scope.lookup('bytes').type str_type = builtin_scope.lookup('str').type unicode_type = builtin_scope.lookup('unicode').type basestring_type = builtin_scope.lookup('basestring').type bytearray_type = builtin_scope.lookup('bytearray').type + memoryview_type = builtin_scope.lookup('memoryview').type + float_type = builtin_scope.lookup('float').type int_type = builtin_scope.lookup('int').type long_type = builtin_scope.lookup('long').type bool_type = builtin_scope.lookup('bool').type complex_type = builtin_scope.lookup('complex').type - memoryview_type = builtin_scope.lookup('memoryview').type + + 
sequence_types = ( + list_type, + tuple_type, + bytes_type, + str_type, + unicode_type, + basestring_type, + bytearray_type, + memoryview_type, + ) # Set up type inference links between equivalent Python/C types bool_type.equivalent_type = PyrexTypes.c_bint_type diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 6097297b8..6469000f7 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -13,7 +13,8 @@ cython.declare(error=object, warning=object, warn_once=object, InternalError=obj unicode_type=object, str_type=object, bytes_type=object, type_type=object, Builtin=object, Symtab=object, Utils=object, find_coercion_error=object, debug_disposal_code=object, debug_temp_alloc=object, debug_coercion=object, - bytearray_type=object, slice_type=object, _py_int_types=object, + bytearray_type=object, slice_type=object, memoryview_type=object, + builtin_sequence_types=object, _py_int_types=object, IS_PYTHON3=cython.bint) import re @@ -37,7 +38,7 @@ from . import TypeSlots from .Builtin import ( list_type, tuple_type, set_type, dict_type, type_type, unicode_type, str_type, bytes_type, bytearray_type, basestring_type, - slice_type, long_type, + slice_type, long_type, sequence_types as builtin_sequence_types, memoryview_type, ) from . import Builtin from . 
import Symtab @@ -8433,6 +8434,12 @@ class TupleNode(SequenceNode): return env.declare_tuple_type(self.pos, arg_types).type def analyse_types(self, env, skip_children=False): + # reset before re-analysing + if self.is_literal: + self.is_literal = False + if self.is_partly_literal: + self.is_partly_literal = False + if len(self.args) == 0: self.is_temp = False self.is_literal = True @@ -8463,7 +8470,7 @@ class TupleNode(SequenceNode): node.is_temp = False node.is_literal = True else: - if not node.mult_factor.type.is_pyobject: + if not node.mult_factor.type.is_pyobject and not node.mult_factor.type.is_int: node.mult_factor = node.mult_factor.coerce_to_pyobject(env) node.is_temp = True node.is_partly_literal = True @@ -8485,7 +8492,13 @@ class TupleNode(SequenceNode): return self.coerce_to_ctuple(dst_type, env) elif dst_type is tuple_type or dst_type is py_object_type: coerced_args = [arg.coerce_to_pyobject(env) for arg in self.args] - return TupleNode(self.pos, args=coerced_args, type=tuple_type, is_temp=1).analyse_types(env, skip_children=True) + return TupleNode( + self.pos, + args=coerced_args, + type=tuple_type, + mult_factor=self.mult_factor, + is_temp=1, + ).analyse_types(env, skip_children=True) else: return self.coerce_to_pyobject(env).coerce_to(dst_type, env) elif dst_type.is_ctuple and not self.mult_factor: @@ -8542,6 +8555,14 @@ class TupleNode(SequenceNode): const_code.put_giveref(tuple_target, py_object_type) if self.is_literal: self.result_code = tuple_target + elif self.mult_factor.type.is_int: + code.globalstate.use_utility_code( + UtilityCode.load_cached("PySequenceMultiply", "ObjectHandling.c")) + code.putln('%s = __Pyx_PySequence_Multiply(%s, %s); %s' % ( + self.result(), tuple_target, self.mult_factor.result(), + code.error_goto_if_null(self.result(), self.pos) + )) + self.generate_gotref(code) else: code.putln('%s = PyNumber_Multiply(%s, %s); %s' % ( self.result(), tuple_target, self.mult_factor.py_result(), @@ -11680,6 +11701,8 @@ class 
BinopNode(ExprNode): self.operand1.is_ephemeral() or self.operand2.is_ephemeral()) def generate_result_code(self, code): + type1 = self.operand1.type + type2 = self.operand2.type if self.type.is_pythran_expr: code.putln("// Pythran binop") code.putln("__Pyx_call_destructor(%s);" % self.result()) @@ -11696,18 +11719,17 @@ class BinopNode(ExprNode): self.operand1.pythran_result(), self.operator, self.operand2.pythran_result())) - elif self.operand1.type.is_pyobject: + elif type1.is_pyobject or type2.is_pyobject: function = self.py_operation_function(code) - if self.operator == '**': - extra_args = ", Py_None" - else: - extra_args = "" + extra_args = ", Py_None" if self.operator == '**' else "" + op1_result = self.operand1.py_result() if type1.is_pyobject else self.operand1.result() + op2_result = self.operand2.py_result() if type2.is_pyobject else self.operand2.result() code.putln( "%s = %s(%s, %s%s); %s" % ( self.result(), function, - self.operand1.py_result(), - self.operand2.py_result(), + op1_result, + op2_result, extra_args, code.error_goto_if_null(self.result(), self.pos))) self.generate_gotref(code) @@ -11993,40 +12015,76 @@ class SubNode(NumBinopNode): class MulNode(NumBinopNode): # '*' operator. + is_sequence_mul = False def analyse_types(self, env): + self.operand1 = self.operand1.analyse_types(env) + self.operand2 = self.operand2.analyse_types(env) + self.is_sequence_mul = self.calculate_is_sequence_mul() + # TODO: we could also optimise the case of "[...] * 2 * n", i.e. 
with an existing 'mult_factor' - if self.operand1.is_sequence_constructor and self.operand1.mult_factor is None: - operand2 = self.operand2.analyse_types(env) - if operand2.type.is_int or operand2.type is long_type: - return self.analyse_sequence_mul(env, self.operand1, operand2) - elif self.operand2.is_sequence_constructor and self.operand2.mult_factor is None: - operand1 = self.operand1.analyse_types(env) - if operand1.type.is_int or operand1.type is long_type: - return self.analyse_sequence_mul(env, self.operand2, operand1) + if self.is_sequence_mul: + operand1 = self.operand1 + operand2 = self.operand2 + if operand1.is_sequence_constructor and operand1.mult_factor is None: + return self.analyse_sequence_mul(env, operand1, operand2) + elif operand2.is_sequence_constructor and operand2.mult_factor is None: + return self.analyse_sequence_mul(env, operand2, operand1) - return NumBinopNode.analyse_types(self, env) + self.analyse_operation(env) + return self + + @staticmethod + def is_builtin_seqmul_type(type): + return type.is_builtin_type and type in builtin_sequence_types and type is not memoryview_type + + def calculate_is_sequence_mul(self): + type1 = self.operand1.type + type2 = self.operand2.type + if type1 is long_type or type1.is_int: + # normalise to (X * int) + type1, type2 = type2, type1 + if type2 is long_type or type2.is_int: + if type1.is_string or type1.is_ctuple: + return True + if self.is_builtin_seqmul_type(type1): + return True + return False def analyse_sequence_mul(self, env, seq, mult): assert seq.mult_factor is None + seq = seq.coerce_to_pyobject(env) seq.mult_factor = mult return seq.analyse_types(env) + def coerce_operands_to_pyobjects(self, env): + if self.is_sequence_mul: + # Keep operands as they are, but ctuples must become Python tuples to multiply them. 
+ if self.operand1.type.is_ctuple: + self.operand1 = self.operand1.coerce_to_pyobject(env) + elif self.operand2.type.is_ctuple: + self.operand2 = self.operand2.coerce_to_pyobject(env) + return + super(MulNode, self).coerce_operands_to_pyobjects(env) + def is_py_operation_types(self, type1, type2): - if ((type1.is_string and type2.is_int) or - (type2.is_string and type1.is_int)): - return 1 - else: - return NumBinopNode.is_py_operation_types(self, type1, type2) + return self.is_sequence_mul or super(MulNode, self).is_py_operation_types(type1, type2) + + def py_operation_function(self, code): + if self.is_sequence_mul: + code.globalstate.use_utility_code( + UtilityCode.load_cached("PySequenceMultiply", "ObjectHandling.c")) + return "__Pyx_PySequence_Multiply" if self.operand1.type.is_pyobject else "__Pyx_PySequence_Multiply_Left" + return super(MulNode, self).py_operation_function(code) def infer_builtin_types_operation(self, type1, type2): - # let's assume that whatever builtin type you multiply a string with - # will either return a string of the same type or fail with an exception - string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type) - if type1 in string_types and type2.is_builtin_type: - return type1 - if type2 in string_types and type1.is_builtin_type: - return type2 + # let's assume that whatever builtin type you multiply a builtin sequence type with + # will either return a sequence of the same type or fail with an exception + if type1.is_builtin_type and type2.is_builtin_type: + if self.is_builtin_seqmul_type(type1): + return type1 + if self.is_builtin_seqmul_type(type2): + return type2 # multiplication of containers/numbers with an integer value # always (?) 
returns the same type if type1.is_int: diff --git a/Cython/Utility/ObjectHandling.c b/Cython/Utility/ObjectHandling.c index df9cc3457..e97569895 100644 --- a/Cython/Utility/ObjectHandling.c +++ b/Cython/Utility/ObjectHandling.c @@ -3086,6 +3086,36 @@ static CYTHON_INLINE PyObject *__Pyx_PyUnicode_ConcatInPlaceImpl(PyObject **p_le #define __Pyx_PyStr_ConcatInPlaceSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ? \ PyNumber_InPlaceAdd(a, b) : __Pyx_PyStr_ConcatInPlace(a, b)) + +/////////////// PySequenceMultiply.proto /////////////// + +#define __Pyx_PySequence_Multiply_Left(mul, seq) __Pyx_PySequence_Multiply(seq, mul) +static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul); + +/////////////// PySequenceMultiply /////////////// + +static PyObject* __Pyx_PySequence_Multiply_Generic(PyObject *seq, Py_ssize_t mul) { + PyObject *result, *pymul = PyInt_FromSsize_t(mul); + if (unlikely(!pymul)) + return NULL; + result = PyNumber_Multiply(seq, pymul); + Py_DECREF(pymul); + return result; +} + +static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul) { +#if CYTHON_USE_TYPE_SLOTS + PyTypeObject *type = Py_TYPE(seq); + if (likely(type->tp_as_sequence && type->tp_as_sequence->sq_repeat)) { + return type->tp_as_sequence->sq_repeat(seq, mul); + } else +#endif + { + return __Pyx_PySequence_Multiply_Generic(seq, mul); + } +} + + /////////////// FormatTypeName.proto /////////////// #if CYTHON_COMPILING_IN_LIMITED_API diff --git a/tests/run/seq_mul.py b/tests/run/seq_mul.py index 2c4a41e65..f600452a9 100644 --- a/tests/run/seq_mul.py +++ b/tests/run/seq_mul.py @@ -48,6 +48,27 @@ def list_times_cint(n: cython.int): @cython.test_fail_if_path_exists("//MulNode") @cython.test_assert_path_exists("//TupleNode[@mult_factor]") +def const_times_tuple(v: cython.int): + """ + >>> const_times_tuple(4) + () + (None, None) + (4, 4) + (1, 2, 3, 1, 2, 3) + """ + a = 2 * () + b = 2 * (None,) + c = 2 * (v,) + d = 2 * 
(1, 2, 3) + + print(a) + print(b) + print(c) + print(d) + + +@cython.test_fail_if_path_exists("//MulNode") +@cython.test_assert_path_exists("//TupleNode[@mult_factor]") def cint_times_tuple(n: cython.int): """ >>> cint_times_tuple(3) diff --git a/tests/run/unicodemethods.pyx b/tests/run/unicodemethods.pyx index a3b1a3333..f0367c267 100644 --- a/tests/run/unicodemethods.pyx +++ b/tests/run/unicodemethods.pyx @@ -777,3 +777,158 @@ def replace_maxcount(unicode s, substring, repl, maxcount): ab jd sdflk as SA sadas asdas fsdf\x20 """ return s.replace(substring, repl, maxcount) + + +# unicode * int + +@cython.test_fail_if_path_exists( + "//CoerceToPyTypeNode", +) +@cython.test_assert_path_exists( + "//MulNode[@is_sequence_mul = True]", +) +def multiply(unicode ustring, int mul): + """ + >>> astr = u"abc" + >>> ustr = u"abcüöä\\U0001F642" + + >>> print(multiply(astr, -1)) + <BLANKLINE> + >>> print(multiply(ustr, -1)) + <BLANKLINE> + + >>> print(multiply(astr, 0)) + <BLANKLINE> + >>> print(multiply(ustr, 0)) + <BLANKLINE> + + >>> print(multiply(astr, 1)) + abc + >>> print(multiply(ustr, 1)) + abcüöä\U0001F642 + + >>> print(multiply(astr, 2)) + abcabc + >>> print(multiply(ustr, 2)) + abcüöä\U0001F642abcüöä\U0001F642 + + >>> print(multiply(astr, 5)) + abcabcabcabcabc + >>> print(multiply(ustr, 5)) + abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642 + """ + return ustring * mul + + +#@cython.test_fail_if_path_exists( +# "//CoerceToPyTypeNode", +# "//CastNode", "//TypecastNode") +#@cython.test_assert_path_exists( +# "//PythonCapiCallNode") +def multiply_inplace(unicode ustring, int mul): + """ + >>> astr = u"abc" + >>> ustr = u"abcüöä\\U0001F642" + + >>> print(multiply_inplace(astr, -1)) + <BLANKLINE> + >>> print(multiply_inplace(ustr, -1)) + <BLANKLINE> + + >>> print(multiply_inplace(astr, 0)) + <BLANKLINE> + >>> print(multiply_inplace(ustr, 0)) + <BLANKLINE> + + >>> print(multiply_inplace(astr, 1)) + abc + >>> print(multiply_inplace(ustr, 1)) 
+ abcüöä\U0001F642 + + >>> print(multiply_inplace(astr, 2)) + abcabc + >>> print(multiply_inplace(ustr, 2)) + abcüöä\U0001F642abcüöä\U0001F642 + + >>> print(multiply_inplace(astr, 5)) + abcabcabcabcabc + >>> print(multiply_inplace(ustr, 5)) + abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642 + """ + ustring *= mul + return ustring + + +@cython.test_fail_if_path_exists( + "//CoerceToPyTypeNode", +) +@cython.test_assert_path_exists( + "//MulNode[@is_sequence_mul = True]", +) +def multiply_reversed(unicode ustring, int mul): + """ + >>> astr = u"abc" + >>> ustr = u"abcüöä\\U0001F642" + + >>> print(multiply_reversed(astr, -1)) + <BLANKLINE> + >>> print(multiply_reversed(ustr, -1)) + <BLANKLINE> + + >>> print(multiply_reversed(astr, 0)) + <BLANKLINE> + >>> print(multiply_reversed(ustr, 0)) + <BLANKLINE> + + >>> print(multiply_reversed(astr, 1)) + abc + >>> print(multiply_reversed(ustr, 1)) + abcüöä\U0001F642 + + >>> print(multiply_reversed(astr, 2)) + abcabc + >>> print(multiply_reversed(ustr, 2)) + abcüöä\U0001F642abcüöä\U0001F642 + + >>> print(multiply_reversed(astr, 5)) + abcabcabcabcabc + >>> print(multiply_reversed(ustr, 5)) + abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642 + """ + return mul * ustring + + +@cython.test_fail_if_path_exists( + "//CoerceToPyTypeNode", +) +def unicode__mul__(unicode ustring, int mul): + """ + >>> astr = u"abc" + >>> ustr = u"abcüöä\\U0001F642" + + >>> print(unicode__mul__(astr, -1)) + <BLANKLINE> + >>> print(unicode__mul__(ustr, -1)) + <BLANKLINE> + + >>> print(unicode__mul__(astr, 0)) + <BLANKLINE> + >>> print(unicode__mul__(ustr, 0)) + <BLANKLINE> + + >>> print(unicode__mul__(astr, 1)) + abc + >>> print(unicode__mul__(ustr, 1)) + abcüöä\U0001F642 + + >>> print(unicode__mul__(astr, 2)) + abcabc + >>> print(unicode__mul__(ustr, 2)) + abcüöä\U0001F642abcüöä\U0001F642 + + >>> print(unicode__mul__(astr, 5)) + abcabcabcabcabc + >>> print(unicode__mul__(ustr, 5)) + 
abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642 + """ + return ustring.__mul__(mul) |