summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorscoder <stefan_ml@behnel.de>2023-04-21 08:25:20 +0200
committerGitHub <noreply@github.com>2023-04-21 08:25:20 +0200
commitbcb6c0e1af0086b7900db518e73e61e48ee5167b (patch)
tree2c8a9a3cdaccc0f1b77804afe3681fbf871a66d5
parenta75afc03415f1b887434991ac0db2de0e6555ed7 (diff)
downloadcython-bcb6c0e1af0086b7900db518e73e61e48ee5167b.tar.gz
Avoid Python int object creation when multiplying sequences with C integers (GH-5213)
* Avoid redundant subtree analysis in MulNode when multiplying sequences with unknown types. * Avoid Python int creation when multiplying sequences with integers. * Also allow a cint mult_factor for sequences, avoiding Python coercion if possible. * Also optimise (int * ctuple), which will eventually end up as a Python tuple as well. * Make sure we only apply a "mult_factor" to a Python sequence (not ctuples), and make the re-analysis of TupleNode a little safer.
-rw-r--r--Cython/Compiler/Builtin.py36
-rw-r--r--Cython/Compiler/ExprNodes.py122
-rw-r--r--Cython/Utility/ObjectHandling.c30
-rw-r--r--tests/run/seq_mul.py21
-rw-r--r--tests/run/unicodemethods.pyx155
5 files changed, 328 insertions, 36 deletions
diff --git a/Cython/Compiler/Builtin.py b/Cython/Compiler/Builtin.py
index c84ee16fd..6b00e94f9 100644
--- a/Cython/Compiler/Builtin.py
+++ b/Cython/Compiler/Builtin.py
@@ -292,21 +292,33 @@ builtin_types_table = [
("basestring", "PyBaseString_Type", [
BuiltinMethod("join", "TO", "T", "__Pyx_PyBaseString_Join",
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("bytearray", "PyByteArray_Type", [
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("bytes", "PyBytes_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyBytes_Join",
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("str", "PyString_Type", [BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join",
builtin_return_type='basestring',
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("unicode", "PyUnicode_Type", [BuiltinMethod("__contains__", "TO", "b", "PyUnicode_Contains"),
BuiltinMethod("join", "TO", "T", "PyUnicode_Join"),
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
- ("tuple", "PyTuple_Type", []),
+ ("tuple", "PyTuple_Type", [BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
+ ]),
("list", "PyList_Type", [BuiltinMethod("insert", "TzO", "r", "PyList_Insert"),
BuiltinMethod("reverse", "T", "r", "PyList_Reverse"),
@@ -314,6 +326,8 @@ builtin_types_table = [
utility_code=UtilityCode.load("ListAppend", "Optimize.c")),
BuiltinMethod("extend", "TO", "r", "__Pyx_PyList_Extend",
utility_code=UtilityCode.load("ListExtend", "Optimize.c")),
+ BuiltinMethod("__mul__", "Tz", "T", "__Pyx_PySequence_Multiply",
+ utility_code=UtilityCode.load("PySequenceMultiply", "ObjectHandling.c")),
]),
("dict", "PyDict_Type", [BuiltinMethod("__contains__", "TO", "b", "PyDict_Contains"),
@@ -479,10 +493,11 @@ def init_builtins():
'__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
pos=None, cname='__pyx_assertions_enabled()', is_cdef=True)
- global type_type, list_type, tuple_type, dict_type, set_type, frozenset_type
- global slice_type, bytes_type, str_type, unicode_type, basestring_type, bytearray_type
+ global type_type, list_type, tuple_type, dict_type, set_type, frozenset_type, slice_type
+ global bytes_type, str_type, unicode_type, basestring_type, bytearray_type
global float_type, int_type, long_type, bool_type, complex_type
global memoryview_type, py_buffer_type
+ global sequence_types
type_type = builtin_scope.lookup('type').type
list_type = builtin_scope.lookup('list').type
tuple_type = builtin_scope.lookup('tuple').type
@@ -490,17 +505,30 @@ def init_builtins():
set_type = builtin_scope.lookup('set').type
frozenset_type = builtin_scope.lookup('frozenset').type
slice_type = builtin_scope.lookup('slice').type
+
bytes_type = builtin_scope.lookup('bytes').type
str_type = builtin_scope.lookup('str').type
unicode_type = builtin_scope.lookup('unicode').type
basestring_type = builtin_scope.lookup('basestring').type
bytearray_type = builtin_scope.lookup('bytearray').type
+ memoryview_type = builtin_scope.lookup('memoryview').type
+
float_type = builtin_scope.lookup('float').type
int_type = builtin_scope.lookup('int').type
long_type = builtin_scope.lookup('long').type
bool_type = builtin_scope.lookup('bool').type
complex_type = builtin_scope.lookup('complex').type
- memoryview_type = builtin_scope.lookup('memoryview').type
+
+ sequence_types = (
+ list_type,
+ tuple_type,
+ bytes_type,
+ str_type,
+ unicode_type,
+ basestring_type,
+ bytearray_type,
+ memoryview_type,
+ )
# Set up type inference links between equivalent Python/C types
bool_type.equivalent_type = PyrexTypes.c_bint_type
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 6097297b8..6469000f7 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -13,7 +13,8 @@ cython.declare(error=object, warning=object, warn_once=object, InternalError=obj
unicode_type=object, str_type=object, bytes_type=object, type_type=object,
Builtin=object, Symtab=object, Utils=object, find_coercion_error=object,
debug_disposal_code=object, debug_temp_alloc=object, debug_coercion=object,
- bytearray_type=object, slice_type=object, _py_int_types=object,
+ bytearray_type=object, slice_type=object, memoryview_type=object,
+ builtin_sequence_types=object, _py_int_types=object,
IS_PYTHON3=cython.bint)
import re
@@ -37,7 +38,7 @@ from . import TypeSlots
from .Builtin import (
list_type, tuple_type, set_type, dict_type, type_type,
unicode_type, str_type, bytes_type, bytearray_type, basestring_type,
- slice_type, long_type,
+ slice_type, long_type, sequence_types as builtin_sequence_types, memoryview_type,
)
from . import Builtin
from . import Symtab
@@ -8433,6 +8434,12 @@ class TupleNode(SequenceNode):
return env.declare_tuple_type(self.pos, arg_types).type
def analyse_types(self, env, skip_children=False):
+ # reset before re-analysing
+ if self.is_literal:
+ self.is_literal = False
+ if self.is_partly_literal:
+ self.is_partly_literal = False
+
if len(self.args) == 0:
self.is_temp = False
self.is_literal = True
@@ -8463,7 +8470,7 @@ class TupleNode(SequenceNode):
node.is_temp = False
node.is_literal = True
else:
- if not node.mult_factor.type.is_pyobject:
+ if not node.mult_factor.type.is_pyobject and not node.mult_factor.type.is_int:
node.mult_factor = node.mult_factor.coerce_to_pyobject(env)
node.is_temp = True
node.is_partly_literal = True
@@ -8485,7 +8492,13 @@ class TupleNode(SequenceNode):
return self.coerce_to_ctuple(dst_type, env)
elif dst_type is tuple_type or dst_type is py_object_type:
coerced_args = [arg.coerce_to_pyobject(env) for arg in self.args]
- return TupleNode(self.pos, args=coerced_args, type=tuple_type, is_temp=1).analyse_types(env, skip_children=True)
+ return TupleNode(
+ self.pos,
+ args=coerced_args,
+ type=tuple_type,
+ mult_factor=self.mult_factor,
+ is_temp=1,
+ ).analyse_types(env, skip_children=True)
else:
return self.coerce_to_pyobject(env).coerce_to(dst_type, env)
elif dst_type.is_ctuple and not self.mult_factor:
@@ -8542,6 +8555,14 @@ class TupleNode(SequenceNode):
const_code.put_giveref(tuple_target, py_object_type)
if self.is_literal:
self.result_code = tuple_target
+ elif self.mult_factor.type.is_int:
+ code.globalstate.use_utility_code(
+ UtilityCode.load_cached("PySequenceMultiply", "ObjectHandling.c"))
+ code.putln('%s = __Pyx_PySequence_Multiply(%s, %s); %s' % (
+ self.result(), tuple_target, self.mult_factor.result(),
+ code.error_goto_if_null(self.result(), self.pos)
+ ))
+ self.generate_gotref(code)
else:
code.putln('%s = PyNumber_Multiply(%s, %s); %s' % (
self.result(), tuple_target, self.mult_factor.py_result(),
@@ -11680,6 +11701,8 @@ class BinopNode(ExprNode):
self.operand1.is_ephemeral() or self.operand2.is_ephemeral())
def generate_result_code(self, code):
+ type1 = self.operand1.type
+ type2 = self.operand2.type
if self.type.is_pythran_expr:
code.putln("// Pythran binop")
code.putln("__Pyx_call_destructor(%s);" % self.result())
@@ -11696,18 +11719,17 @@ class BinopNode(ExprNode):
self.operand1.pythran_result(),
self.operator,
self.operand2.pythran_result()))
- elif self.operand1.type.is_pyobject:
+ elif type1.is_pyobject or type2.is_pyobject:
function = self.py_operation_function(code)
- if self.operator == '**':
- extra_args = ", Py_None"
- else:
- extra_args = ""
+ extra_args = ", Py_None" if self.operator == '**' else ""
+ op1_result = self.operand1.py_result() if type1.is_pyobject else self.operand1.result()
+ op2_result = self.operand2.py_result() if type2.is_pyobject else self.operand2.result()
code.putln(
"%s = %s(%s, %s%s); %s" % (
self.result(),
function,
- self.operand1.py_result(),
- self.operand2.py_result(),
+ op1_result,
+ op2_result,
extra_args,
code.error_goto_if_null(self.result(), self.pos)))
self.generate_gotref(code)
@@ -11993,40 +12015,76 @@ class SubNode(NumBinopNode):
class MulNode(NumBinopNode):
# '*' operator.
+ is_sequence_mul = False
def analyse_types(self, env):
+ self.operand1 = self.operand1.analyse_types(env)
+ self.operand2 = self.operand2.analyse_types(env)
+ self.is_sequence_mul = self.calculate_is_sequence_mul()
+
# TODO: we could also optimise the case of "[...] * 2 * n", i.e. with an existing 'mult_factor'
- if self.operand1.is_sequence_constructor and self.operand1.mult_factor is None:
- operand2 = self.operand2.analyse_types(env)
- if operand2.type.is_int or operand2.type is long_type:
- return self.analyse_sequence_mul(env, self.operand1, operand2)
- elif self.operand2.is_sequence_constructor and self.operand2.mult_factor is None:
- operand1 = self.operand1.analyse_types(env)
- if operand1.type.is_int or operand1.type is long_type:
- return self.analyse_sequence_mul(env, self.operand2, operand1)
+ if self.is_sequence_mul:
+ operand1 = self.operand1
+ operand2 = self.operand2
+ if operand1.is_sequence_constructor and operand1.mult_factor is None:
+ return self.analyse_sequence_mul(env, operand1, operand2)
+ elif operand2.is_sequence_constructor and operand2.mult_factor is None:
+ return self.analyse_sequence_mul(env, operand2, operand1)
- return NumBinopNode.analyse_types(self, env)
+ self.analyse_operation(env)
+ return self
+
+ @staticmethod
+ def is_builtin_seqmul_type(type):
+ return type.is_builtin_type and type in builtin_sequence_types and type is not memoryview_type
+
+ def calculate_is_sequence_mul(self):
+ type1 = self.operand1.type
+ type2 = self.operand2.type
+ if type1 is long_type or type1.is_int:
+ # normalise to (X * int)
+ type1, type2 = type2, type1
+ if type2 is long_type or type2.is_int:
+ if type1.is_string or type1.is_ctuple:
+ return True
+ if self.is_builtin_seqmul_type(type1):
+ return True
+ return False
def analyse_sequence_mul(self, env, seq, mult):
assert seq.mult_factor is None
+ seq = seq.coerce_to_pyobject(env)
seq.mult_factor = mult
return seq.analyse_types(env)
+ def coerce_operands_to_pyobjects(self, env):
+ if self.is_sequence_mul:
+ # Keep operands as they are, but ctuples must become Python tuples to multiply them.
+ if self.operand1.type.is_ctuple:
+ self.operand1 = self.operand1.coerce_to_pyobject(env)
+ elif self.operand2.type.is_ctuple:
+ self.operand2 = self.operand2.coerce_to_pyobject(env)
+ return
+ super(MulNode, self).coerce_operands_to_pyobjects(env)
+
def is_py_operation_types(self, type1, type2):
- if ((type1.is_string and type2.is_int) or
- (type2.is_string and type1.is_int)):
- return 1
- else:
- return NumBinopNode.is_py_operation_types(self, type1, type2)
+ return self.is_sequence_mul or super(MulNode, self).is_py_operation_types(type1, type2)
+
+ def py_operation_function(self, code):
+ if self.is_sequence_mul:
+ code.globalstate.use_utility_code(
+ UtilityCode.load_cached("PySequenceMultiply", "ObjectHandling.c"))
+ return "__Pyx_PySequence_Multiply" if self.operand1.type.is_pyobject else "__Pyx_PySequence_Multiply_Left"
+ return super(MulNode, self).py_operation_function(code)
def infer_builtin_types_operation(self, type1, type2):
- # let's assume that whatever builtin type you multiply a string with
- # will either return a string of the same type or fail with an exception
- string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
- if type1 in string_types and type2.is_builtin_type:
- return type1
- if type2 in string_types and type1.is_builtin_type:
- return type2
+ # let's assume that whatever builtin type you multiply a builtin sequence type with
+ # will either return a sequence of the same type or fail with an exception
+ if type1.is_builtin_type and type2.is_builtin_type:
+ if self.is_builtin_seqmul_type(type1):
+ return type1
+ if self.is_builtin_seqmul_type(type2):
+ return type2
# multiplication of containers/numbers with an integer value
# always (?) returns the same type
if type1.is_int:
diff --git a/Cython/Utility/ObjectHandling.c b/Cython/Utility/ObjectHandling.c
index df9cc3457..e97569895 100644
--- a/Cython/Utility/ObjectHandling.c
+++ b/Cython/Utility/ObjectHandling.c
@@ -3086,6 +3086,36 @@ static CYTHON_INLINE PyObject *__Pyx_PyUnicode_ConcatInPlaceImpl(PyObject **p_le
#define __Pyx_PyStr_ConcatInPlaceSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ? \
PyNumber_InPlaceAdd(a, b) : __Pyx_PyStr_ConcatInPlace(a, b))
+
+/////////////// PySequenceMultiply.proto ///////////////
+
+#define __Pyx_PySequence_Multiply_Left(mul, seq) __Pyx_PySequence_Multiply(seq, mul)
+static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul);
+
+/////////////// PySequenceMultiply ///////////////
+
+static PyObject* __Pyx_PySequence_Multiply_Generic(PyObject *seq, Py_ssize_t mul) {
+ PyObject *result, *pymul = PyInt_FromSsize_t(mul);
+ if (unlikely(!pymul))
+ return NULL;
+ result = PyNumber_Multiply(seq, pymul);
+ Py_DECREF(pymul);
+ return result;
+}
+
+static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul) {
+#if CYTHON_USE_TYPE_SLOTS
+ PyTypeObject *type = Py_TYPE(seq);
+ if (likely(type->tp_as_sequence && type->tp_as_sequence->sq_repeat)) {
+ return type->tp_as_sequence->sq_repeat(seq, mul);
+ } else
+#endif
+ {
+ return __Pyx_PySequence_Multiply_Generic(seq, mul);
+ }
+}
+
+
/////////////// FormatTypeName.proto ///////////////
#if CYTHON_COMPILING_IN_LIMITED_API
diff --git a/tests/run/seq_mul.py b/tests/run/seq_mul.py
index 2c4a41e65..f600452a9 100644
--- a/tests/run/seq_mul.py
+++ b/tests/run/seq_mul.py
@@ -48,6 +48,27 @@ def list_times_cint(n: cython.int):
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//TupleNode[@mult_factor]")
+def const_times_tuple(v: cython.int):
+ """
+ >>> const_times_tuple(4)
+ ()
+ (None, None)
+ (4, 4)
+ (1, 2, 3, 1, 2, 3)
+ """
+ a = 2 * ()
+ b = 2 * (None,)
+ c = 2 * (v,)
+ d = 2 * (1, 2, 3)
+
+ print(a)
+ print(b)
+ print(c)
+ print(d)
+
+
+@cython.test_fail_if_path_exists("//MulNode")
+@cython.test_assert_path_exists("//TupleNode[@mult_factor]")
def cint_times_tuple(n: cython.int):
"""
>>> cint_times_tuple(3)
diff --git a/tests/run/unicodemethods.pyx b/tests/run/unicodemethods.pyx
index a3b1a3333..f0367c267 100644
--- a/tests/run/unicodemethods.pyx
+++ b/tests/run/unicodemethods.pyx
@@ -777,3 +777,158 @@ def replace_maxcount(unicode s, substring, repl, maxcount):
ab jd sdflk as SA sadas asdas fsdf\x20
"""
return s.replace(substring, repl, maxcount)
+
+
+# unicode * int
+
+@cython.test_fail_if_path_exists(
+ "//CoerceToPyTypeNode",
+)
+@cython.test_assert_path_exists(
+ "//MulNode[@is_sequence_mul = True]",
+)
+def multiply(unicode ustring, int mul):
+ """
+ >>> astr = u"abc"
+ >>> ustr = u"abcüöä\\U0001F642"
+
+ >>> print(multiply(astr, -1))
+ <BLANKLINE>
+ >>> print(multiply(ustr, -1))
+ <BLANKLINE>
+
+ >>> print(multiply(astr, 0))
+ <BLANKLINE>
+ >>> print(multiply(ustr, 0))
+ <BLANKLINE>
+
+ >>> print(multiply(astr, 1))
+ abc
+ >>> print(multiply(ustr, 1))
+ abcüöä\U0001F642
+
+ >>> print(multiply(astr, 2))
+ abcabc
+ >>> print(multiply(ustr, 2))
+ abcüöä\U0001F642abcüöä\U0001F642
+
+ >>> print(multiply(astr, 5))
+ abcabcabcabcabc
+ >>> print(multiply(ustr, 5))
+ abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642
+ """
+ return ustring * mul
+
+
+#@cython.test_fail_if_path_exists(
+# "//CoerceToPyTypeNode",
+# "//CastNode", "//TypecastNode")
+#@cython.test_assert_path_exists(
+# "//PythonCapiCallNode")
+def multiply_inplace(unicode ustring, int mul):
+ """
+ >>> astr = u"abc"
+ >>> ustr = u"abcüöä\\U0001F642"
+
+ >>> print(multiply_inplace(astr, -1))
+ <BLANKLINE>
+ >>> print(multiply_inplace(ustr, -1))
+ <BLANKLINE>
+
+ >>> print(multiply_inplace(astr, 0))
+ <BLANKLINE>
+ >>> print(multiply_inplace(ustr, 0))
+ <BLANKLINE>
+
+ >>> print(multiply_inplace(astr, 1))
+ abc
+ >>> print(multiply_inplace(ustr, 1))
+ abcüöä\U0001F642
+
+ >>> print(multiply_inplace(astr, 2))
+ abcabc
+ >>> print(multiply_inplace(ustr, 2))
+ abcüöä\U0001F642abcüöä\U0001F642
+
+ >>> print(multiply_inplace(astr, 5))
+ abcabcabcabcabc
+ >>> print(multiply_inplace(ustr, 5))
+ abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642
+ """
+ ustring *= mul
+ return ustring
+
+
+@cython.test_fail_if_path_exists(
+ "//CoerceToPyTypeNode",
+)
+@cython.test_assert_path_exists(
+ "//MulNode[@is_sequence_mul = True]",
+)
+def multiply_reversed(unicode ustring, int mul):
+ """
+ >>> astr = u"abc"
+ >>> ustr = u"abcüöä\\U0001F642"
+
+ >>> print(multiply_reversed(astr, -1))
+ <BLANKLINE>
+ >>> print(multiply_reversed(ustr, -1))
+ <BLANKLINE>
+
+ >>> print(multiply_reversed(astr, 0))
+ <BLANKLINE>
+ >>> print(multiply_reversed(ustr, 0))
+ <BLANKLINE>
+
+ >>> print(multiply_reversed(astr, 1))
+ abc
+ >>> print(multiply_reversed(ustr, 1))
+ abcüöä\U0001F642
+
+ >>> print(multiply_reversed(astr, 2))
+ abcabc
+ >>> print(multiply_reversed(ustr, 2))
+ abcüöä\U0001F642abcüöä\U0001F642
+
+ >>> print(multiply_reversed(astr, 5))
+ abcabcabcabcabc
+ >>> print(multiply_reversed(ustr, 5))
+ abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642
+ """
+ return mul * ustring
+
+
+@cython.test_fail_if_path_exists(
+ "//CoerceToPyTypeNode",
+)
+def unicode__mul__(unicode ustring, int mul):
+ """
+ >>> astr = u"abc"
+ >>> ustr = u"abcüöä\\U0001F642"
+
+ >>> print(unicode__mul__(astr, -1))
+ <BLANKLINE>
+ >>> print(unicode__mul__(ustr, -1))
+ <BLANKLINE>
+
+ >>> print(unicode__mul__(astr, 0))
+ <BLANKLINE>
+ >>> print(unicode__mul__(ustr, 0))
+ <BLANKLINE>
+
+ >>> print(unicode__mul__(astr, 1))
+ abc
+ >>> print(unicode__mul__(ustr, 1))
+ abcüöä\U0001F642
+
+ >>> print(unicode__mul__(astr, 2))
+ abcabc
+ >>> print(unicode__mul__(ustr, 2))
+ abcüöä\U0001F642abcüöä\U0001F642
+
+ >>> print(unicode__mul__(astr, 5))
+ abcabcabcabcabc
+ >>> print(unicode__mul__(ustr, 5))
+ abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642abcüöä\U0001F642
+ """
+ return ustring.__mul__(mul)