summaryrefslogtreecommitdiff
path: root/numpy/core/shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-11-12 22:53:55 -0800
committerEric Wieser <wieser.eric@gmail.com>2017-11-12 23:03:12 -0800
commitae338e4c7deb4268ed9122b8dd4922424eacd58f (patch)
treeef1782e6eadf88cb1c387710addea603c5618e7b /numpy/core/shape_base.py
parent6789c257e2f5c1d90b9ca961dd655896c6fa9ea5 (diff)
downloadnumpy-ae338e4c7deb4268ed9122b8dd4922424eacd58f.tar.gz
REV: Undo bad rebase in gh-8981 (7fdfdd6a52fc0761c0d45931247c5ed2480224eb)
This restores the changes in gh-9667 that were overwritten.
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r--numpy/core/shape_base.py236
1 files changed, 88 insertions, 148 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 026ad603a..8a047fdda 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -365,78 +365,93 @@ def stack(arrays, axis=0, out=None):
return _nx.concatenate(expanded_arrays, axis=axis, out=out)
-class _Recurser(object):
+def _block_check_depths_match(arrays, parent_index=[]):
"""
- Utility class for recursing over nested iterables
+ Recursive function checking that the depths of nested lists in `arrays`
+ all match. Mismatch raises a ValueError as described in the block
+ docstring below.
+
+ The entire index (rather than just the depth) needs to be calculated
+ for each innermost list, in case an error needs to be raised, so that
+ the index of the offending list can be printed as part of the error.
+
+ The parameter `parent_index` is the full index of `arrays` within the
+ nested lists passed to _block_check_depths_match at the top of the
+ recursion.
+ The return value is a pair. The first item returned is the full index
+ of an element (specifically the first element) from the bottom of the
+ nesting in `arrays`. An empty list at the bottom of the nesting is
+ represented by a `None` index.
+ The second item is the maximum of the ndims of the arrays nested in
+ `arrays`.
"""
- def __init__(self, recurse_if):
- self.recurse_if = recurse_if
-
- def map_reduce(self, x, f_map=lambda x, **kwargs: x,
- f_reduce=lambda x, **kwargs: x,
- f_kwargs=lambda **kwargs: kwargs,
- **kwargs):
- """
- Iterate over the nested list, applying:
- * ``f_map`` (T -> U) to items
- * ``f_reduce`` (Iterable[U] -> U) to mapped items
-
- For instance, ``map_reduce([[1, 2], 3, 4])`` is::
-
- f_reduce([
- f_reduce([
- f_map(1),
- f_map(2)
- ]),
- f_map(3),
- f_map(4)
- ]])
-
-
- State can be passed down through the calls with `f_kwargs`,
- to iterables of mapped items. When kwargs are passed, as in
- ``map_reduce([[1, 2], 3, 4], **kw)``, this becomes::
-
- kw1 = f_kwargs(**kw)
- kw2 = f_kwargs(**kw1)
- f_reduce([
- f_reduce([
- f_map(1), **kw2)
- f_map(2, **kw2)
- ], **kw1),
- f_map(3, **kw1),
- f_map(4, **kw1)
- ]], **kw)
- """
- def f(x, **kwargs):
- if not self.recurse_if(x):
- return f_map(x, **kwargs)
- else:
- next_kwargs = f_kwargs(**kwargs)
- return f_reduce((
- f(xi, **next_kwargs)
- for xi in x
- ), **kwargs)
- return f(x, **kwargs)
-
- def walk(self, x, index=()):
- """
- Iterate over x, yielding (index, value, entering), where
-
- * ``index``: a tuple of indices up to this point
- * ``value``: equal to ``x[index[0]][...][index[-1]]``. On the first iteration, is
- ``x`` itself
- * ``entering``: bool. The result of ``recurse_if(value)``
- """
- do_recurse = self.recurse_if(x)
- yield index, x, do_recurse
-
- if not do_recurse:
- return
- for i, xi in enumerate(x):
- # yield from ...
- for v in self.walk(xi, index + (i,)):
- yield v
+ def format_index(index):
+ idx_str = ''.join('[{}]'.format(i) for i in index if i is not None)
+ return 'arrays' + idx_str
+ if type(arrays) is tuple:
+ # not strictly necessary, but saves us from:
+ # - more than one way to do things - no point treating tuples like
+ # lists
+ # - horribly confusing behaviour that results when tuples are
+ # treated like ndarray
+ raise TypeError(
+ '{} is a tuple. '
+ 'Only lists can be used to arrange blocks, and np.block does '
+ 'not allow implicit conversion from tuple to ndarray.'.format(
+ format_index(parent_index)
+ )
+ )
+ elif type(arrays) is list and len(arrays) > 0:
+ idxs_ndims = (_block_check_depths_match(arr, parent_index + [i])
+ for i, arr in enumerate(arrays))
+
+ first_index, max_arr_ndim = next(idxs_ndims)
+ for index, ndim in idxs_ndims:
+ if ndim > max_arr_ndim:
+ max_arr_ndim = ndim
+ if len(index) != len(first_index):
+ raise ValueError(
+ "List depths are mismatched. First element was at depth "
+ "{}, but there is an element at depth {} ({})".format(
+ len(first_index),
+ len(index),
+ format_index(index)
+ )
+ )
+ return first_index, max_arr_ndim
+ elif type(arrays) is list and len(arrays) == 0:
+ # We've 'bottomed out' on an empty list
+ return parent_index + [None], 0
+ else:
+ # We've 'bottomed out' - arrays is either a scalar or an array
+ return parent_index, _nx.ndim(arrays)
+
+
+def _block(arrays, max_depth, result_ndim):
+ """
+ Internal implementation of block. `arrays` is the argument passed to
+ block. `max_depth` is the depth of nested lists within `arrays` and
+ `result_ndim` is the greatest of the dimensions of the arrays in
+ `arrays` and the depth of the lists in `arrays` (see block docstring
+ for details).
+ """
+ def atleast_nd(a, ndim):
+ # Ensures `a` has at least `ndim` dimensions by prepending
+ # ones to `a.shape` as necessary
+ return array(a, ndmin=ndim, copy=False, subok=True)
+
+ def block_recursion(arrays, depth=0):
+ if depth < max_depth:
+ if len(arrays) == 0:
+ raise ValueError('Lists cannot be empty')
+ arrs = [block_recursion(arr, depth+1) for arr in arrays]
+ return _nx.concatenate(arrs, axis=-(max_depth-depth))
+ else:
+ # We've 'bottomed out' - arrays is either a scalar or an array
+ # type(arrays) is not list
+ return atleast_nd(arrays, result_ndim)
+
+ return block_recursion(arrays)
def block(arrays):
@@ -587,81 +602,6 @@ def block(arrays):
"""
- def atleast_nd(x, ndim):
- x = asanyarray(x)
- diff = max(ndim - x.ndim, 0)
- return x[(None,)*diff + (Ellipsis,)]
-
- def format_index(index):
- return 'arrays' + ''.join('[{}]'.format(i) for i in index)
-
- rec = _Recurser(recurse_if=lambda x: type(x) is list)
-
- # ensure that the lists are all matched in depth
- list_ndim = None
- any_empty = False
- for index, value, entering in rec.walk(arrays):
- if type(value) is tuple:
- # not strictly necessary, but saves us from:
- # - more than one way to do things - no point treating tuples like
- # lists
- # - horribly confusing behaviour that results when tuples are
- # treated like ndarray
- raise TypeError(
- '{} is a tuple. '
- 'Only lists can be used to arrange blocks, and np.block does '
- 'not allow implicit conversion from tuple to ndarray.'.format(
- format_index(index)
- )
- )
- if not entering:
- curr_depth = len(index)
- elif len(value) == 0:
- curr_depth = len(index) + 1
- any_empty = True
- else:
- continue
-
- if list_ndim is not None and list_ndim != curr_depth:
- raise ValueError(
- "List depths are mismatched. First element was at depth {}, "
- "but there is an element at depth {} ({})".format(
- list_ndim,
- curr_depth,
- format_index(index)
- )
- )
- list_ndim = curr_depth
-
- # do this here so we catch depth mismatches first
- if any_empty:
- raise ValueError('Lists cannot be empty')
-
- # convert all the arrays to ndarrays
- arrays = rec.map_reduce(arrays,
- f_map=asanyarray,
- f_reduce=list
- )
-
- # determine the maximum dimension of the elements
- elem_ndim = rec.map_reduce(arrays,
- f_map=lambda xi: xi.ndim,
- f_reduce=max
- )
- ndim = max(list_ndim, elem_ndim)
-
- # first axis to concatenate along
- first_axis = ndim - list_ndim
-
- # Make all the elements the same dimension
- arrays = rec.map_reduce(arrays,
- f_map=lambda xi: atleast_nd(xi, ndim),
- f_reduce=list
- )
-
- # concatenate innermost lists on the right, outermost on the left
- return rec.map_reduce(arrays,
- f_reduce=lambda xs, axis: _nx.concatenate(list(xs), axis=axis),
- f_kwargs=lambda axis: dict(axis=axis+1),
- axis=first_axis
- )
+ bottom_index, arr_ndim = _block_check_depths_match(arrays)
+ list_ndim = len(bottom_index)
+ return _block(arrays, list_ndim, max(arr_ndim, list_ndim))