From 8f9ec01afa2b2dee209010f6155da27af02de96d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 12 Jun 2019 00:08:06 -0400 Subject: MAINT: avoid nested dispatch in numpy.core.shape_base (#13634) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MAINT: avoid nested dispatch in numpy.core.shape_base This is a partial reprise of the optimizations from GH-13585. The trade-offs here are about readability, performance and whether these functions automatically work on ndarray subclasses. You'll have to judge the readability costs for yourself, but I think this is pretty reasonable. Here are the performance numbers for three relevant functions with the following IPython script: import numpy as np x = np.array([1]) xs = [x, x, x] for func in [np.stack, np.vstack, np.block]: %timeit func(xs) | Function | Master | This PR | | --- | --- | --- | | `stack` | 6.36 µs ± 175 ns | 6 µs ± 174 ns | | `vstack` | 7.18 µs ± 186 ns | 5.43 µs ± 125 ns | | `block` | 15.1 µs ± 141 ns | 11.3 µs ± 104 ns | The performance benefit for `stack` is somewhat marginal (perhaps it should be dropped), but it's much more meaningful inside `vstack`/`hstack` and `block`, because these functions call other dispatched functions within a loop. For automatically working on ndarray subclasses, the main concern would be that by skipping dispatch with `concatenate`, subclasses that define `concatenate` won't automatically get implementations for `*stack` functions. (But I don't think we should consider ourselves obligated to guarantee these implementation details, as I write in GH-13633.) `block` also will not get an automatic implementation, but given that `block` uses two different code paths depending on argument values, this is probably a good thing, because there's no way the code path not involving `concatenate` could automatically work (it uses `empty()`). * MAINT: only remove internal use in np.block * MAINT: fixup comment on np.block optimization --- numpy/core/shape_base.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'numpy/core/shape_base.py') diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index c23ffa935..710f64827 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -123,7 +123,7 @@ def atleast_2d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1) elif ary.ndim == 1: - result = ary[newaxis,:] + result = ary[newaxis, :] else: result = ary res.append(result) @@ -193,9 +193,9 @@ def atleast_3d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1, 1) elif ary.ndim == 1: - result = ary[newaxis,:, newaxis] + result = ary[newaxis, :, newaxis] elif ary.ndim == 2: - result = ary[:,:, newaxis] + result = ary[:, :, newaxis] else: result = ary res.append(result) @@ -432,6 +432,14 @@ def stack(arrays, axis=0, out=None): return _nx.concatenate(expanded_arrays, axis=axis, out=out) +# Internal functions to eliminate the overhead of repeated dispatch in one of +# the two possible paths inside np.block. +# Use getattr to protect against __array_function__ being disabled. +_size = getattr(_nx.size, '__wrapped__', _nx.size) +_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim) +_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate) + + def _block_format_index(index): """ Convert a list of indices ``[0, 1, 2]`` into ``"arrays[0][1][2]"``. @@ -512,8 +520,8 @@ def _block_check_depths_match(arrays, parent_index=[]): return parent_index + [None], 0, 0 else: # We've 'bottomed out' - arrays is either a scalar or an array - size = _nx.size(arrays) - return parent_index, _nx.ndim(arrays), size + size = _size(arrays) + return parent_index, _ndim(arrays), size def _atleast_nd(a, ndim): @@ -656,7 +664,7 @@ def _block(arrays, max_depth, result_ndim, depth=0): if depth < max_depth: arrs = [_block(arr, max_depth, result_ndim, depth+1) for arr in arrays] - return _nx.concatenate(arrs, axis=-(max_depth-depth)) + return _concatenate(arrs, axis=-(max_depth-depth)) else: # We've 'bottomed out' - arrays is either a scalar or an array # type(arrays) is not list @@ -874,7 +882,7 @@ def _block_slicing(arrays, list_ndim, result_ndim): # Test preferring F only in the case that all input arrays are F F_order = all(arr.flags['F_CONTIGUOUS'] for arr in arrays) - C_order = all(arr.flags['C_CONTIGUOUS'] for arr in arrays) + C_order = all(arr.flags['C_CONTIGUOUS'] for arr in arrays) order = 'F' if F_order and not C_order else 'C' result = _nx.empty(shape=shape, dtype=dtype, order=order) # Note: In a c implementation, the function -- cgit v1.2.1