summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwill-ca <37680486+will-ca@users.noreply.github.com>2020-04-01 03:08:37 -0700
committerGitHub <noreply@github.com>2020-04-01 12:08:37 +0200
commit27b5adbb461675ef775aee46d17b0a6d3b2c047e (patch)
tree170506544e7ac4c7257ea47a3ca1eabdf3440b9b
parent4fd901849ff30399289ebf6613c81d8075a79cc4 (diff)
downloadcython-27b5adbb461675ef775aee46d17b0a6d3b2c047e.tar.gz
Make fused function dispatch O(n) for `cpdef` functions. (GH-3366)
* Rewrote signature matching for fused cpdef function dispatch to use a pre-built tree index in a mutable default argument and be O(n). * Added test to ensure proper differentiation between ambiguously compatible and definitely compatible arguments. * Added test to ensure fused cpdef's can be called by the module itself during import. * Added test to ensure consistent handling of ambiguous fused cpdef signatures. * Test for explicitly defined fused cpdef method. * Add .komodoproject to .gitignore. * Add /cython_debug/ to .gitignore. Closes #1385.
-rw-r--r--.gitignore3
-rw-r--r--Cython/Compiler/FusedNode.py67
-rw-r--r--tests/run/fused_cpdef.pyx104
3 files changed, 148 insertions, 26 deletions
diff --git a/.gitignore b/.gitignore
index 665ba4743..fd5461a8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ Demos/*/*.html
/TEST_TMP/
/build/
+/cython_build/
/wheelhouse*/
!tests/build/
/dist/
@@ -52,3 +53,5 @@ MANIFEST
/.idea
/*.iml
+# Komodo EDIT/IDE project files
+/*.komodoproject
diff --git a/Cython/Compiler/FusedNode.py b/Cython/Compiler/FusedNode.py
index dbf3bbeac..42f15eea3 100644
--- a/Cython/Compiler/FusedNode.py
+++ b/Cython/Compiler/FusedNode.py
@@ -584,6 +584,26 @@ class FusedCFuncDefNode(StatListNode):
{{endif}}
""")
+ def _fused_signature_index(self, pyx_code):
+ """
+ Generate Cython code for constructing a persistent nested dictionary index of
+ fused type specialization signatures.
+ """
+ pyx_code.put_chunk(
+ u"""
+ if not _fused_sigindex:
+ for sig in <dict>signatures:
+ sigindex_node = _fused_sigindex
+ sig_series = sig.strip('()').split('|')
+ for sig_type in sig_series[:-1]:
+ if sig_type not in sigindex_node:
+ sigindex_node[sig_type] = sigindex_node = {}
+ else:
+ sigindex_node = sigindex_node[sig_type]
+ sigindex_node[sig_series[-1]] = sig
+ """
+ )
+
def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
This creates the function that is indexable from Python and does
@@ -620,10 +640,14 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.put_chunk(
u"""
- def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
+ def __pyx_fused_cpdef(signatures, args, kwargs, defaults, *, _fused_sigindex={}):
# FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here!
+ cdef list search_list
+
+ cdef dict sn, sigindex_node
+
dest_sig = [None] * {{n_fused}}
if kwargs is not None and not kwargs:
@@ -691,23 +715,36 @@ class FusedCFuncDefNode(StatListNode):
env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
+ self._fused_signature_index(pyx_code)
+
pyx_code.put_chunk(
u"""
- candidates = []
- for sig in <dict>signatures:
- match_found = False
- src_sig = sig.strip('()').split('|')
- for i in range(len(dest_sig)):
- dst_type = dest_sig[i]
- if dst_type is not None:
- if src_sig[i] == dst_type:
- match_found = True
- else:
- match_found = False
- break
+ sigindex_matches = []
+ sigindex_candidates = [_fused_sigindex]
+
+ for dst_type in dest_sig:
+ found_matches = []
+ found_candidates = []
+ # Make two seperate lists: One for signature sub-trees
+ # with at least one definite match, and another for
+ # signature sub-trees with only ambiguous matches
+ # (where `dest_sig[i] is None`).
+ if dst_type is None:
+ for sn in sigindex_matches:
+ found_matches.extend(sn.values())
+ for sn in sigindex_candidates:
+ found_candidates.extend(sn.values())
+ else:
+ for search_list in (sigindex_matches, sigindex_candidates):
+ for sn in search_list:
+ if dst_type in sn:
+ found_matches.append(sn[dst_type])
+ sigindex_matches = found_matches
+ sigindex_candidates = found_candidates
+ if not (found_matches or found_candidates):
+ break
- if match_found:
- candidates.append(sig)
+ candidates = sigindex_matches
if not candidates:
raise TypeError("No matching signature found")
diff --git a/tests/run/fused_cpdef.pyx b/tests/run/fused_cpdef.pyx
index 0b63c8b98..4a614e0f4 100644
--- a/tests/run/fused_cpdef.pyx
+++ b/tests/run/fused_cpdef.pyx
@@ -1,13 +1,17 @@
+# cython: language_level=3
+# mode: run
+
cimport cython
+import sys, io
cy = __import__("cython")
cpdef func1(self, cython.integral x):
- print "%s," % (self,),
+ print(f"{self},", end=' ')
if cython.integral is int:
- print 'x is int', x, cython.typeof(x)
+ print('x is int', x, cython.typeof(x))
else:
- print 'x is long', x, cython.typeof(x)
+ print('x is long', x, cython.typeof(x))
class A(object):
@@ -16,6 +20,18 @@ class A(object):
def __str__(self):
return "A"
+cdef class B:
+ cpdef int meth(self, cython.integral x):
+ print(f"{self},", end=' ')
+ if cython.integral is int:
+ print('x is int', x, cython.typeof(x))
+ else:
+ print('x is long', x, cython.typeof(x))
+ return 0
+
+ def __str__(self):
+ return "B"
+
pyfunc = func1
def test_fused_cpdef():
@@ -32,23 +48,71 @@ def test_fused_cpdef():
A, x is long 2 long
A, x is long 2 long
A, x is long 2 long
+ <BLANKLINE>
+ B, x is long 2 long
"""
func1[int](None, 2)
func1[long](None, 2)
func1(None, 2)
- print
+ print()
pyfunc[cy.int](None, 2)
pyfunc(None, 2)
- print
+ print()
A.meth[cy.int](A(), 2)
A.meth(A(), 2)
A().meth[cy.long](2)
A().meth(2)
+ print()
+
+ B().meth(2)
+
+
+midimport_run = io.StringIO()
+if sys.version_info.major < 3:
+ # Monkey-patch midimport_run.write to accept non-unicode strings under Python 2.
+ midimport_run.write = lambda c: io.StringIO.write(midimport_run, unicode(c))
+
+realstdout = sys.stdout
+sys.stdout = midimport_run
+
+try:
+ # Run `test_fused_cpdef()` during import and save the result for
+ # `test_midimport_run()`.
+ test_fused_cpdef()
+except Exception as e:
+ midimport_run.write(f"{e!r}\n")
+finally:
+ sys.stdout = realstdout
+
+def test_midimport_run():
+ # At one point, dynamically calling fused cpdef functions during import
+ # would fail because the type signature-matching indices weren't
+ # yet initialized.
+ # (See Compiler.FusedNode.FusedCFuncDefNode._fused_signature_index,
+ # GH-3366.)
+ """
+ >>> test_midimport_run()
+ None, x is int 2 int
+ None, x is long 2 long
+ None, x is long 2 long
+ <BLANKLINE>
+ None, x is int 2 int
+ None, x is long 2 long
+ <BLANKLINE>
+ A, x is int 2 int
+ A, x is long 2 long
+ A, x is long 2 long
+ A, x is long 2 long
+ <BLANKLINE>
+ B, x is long 2 long
+ """
+ print(midimport_run.getvalue(), end='')
+
def assert_raise(func, *args):
try:
@@ -70,23 +134,31 @@ def test_badcall():
assert_raise(A.meth)
assert_raise(A().meth[cy.int])
assert_raise(A.meth[cy.int])
+ assert_raise(B().meth, 1, 2, 3)
+
+def test_nomatch():
+ """
+ >>> func1(None, ())
+ Traceback (most recent call last):
+ TypeError: No matching signature found
+ """
ctypedef long double long_double
cpdef multiarg(cython.integral x, cython.floating y):
if cython.integral is int:
- print "x is an int,",
+ print("x is an int,", end=' ')
else:
- print "x is a long,",
+ print("x is a long,", end=' ')
if cython.floating is long_double:
- print "y is a long double:",
+ print("y is a long double:", end=' ')
elif float is cython.floating:
- print "y is a float:",
+ print("y is a float:", end=' ')
else:
- print "y is a double:",
+ print("y is a double:", end=' ')
- print x, y
+ print(x, y)
def test_multiarg():
"""
@@ -104,3 +176,13 @@ def test_multiarg():
multiarg[int, float](1, 2.0)
multiarg[cy.int, cy.float](1, 2.0)
multiarg(4, 5.0)
+
+def test_ambiguousmatch():
+ """
+ >>> multiarg(5, ())
+ Traceback (most recent call last):
+ TypeError: Function call with ambiguous argument types
+ >>> multiarg((), 2.0)
+ Traceback (most recent call last):
+ TypeError: Function call with ambiguous argument types
+ """