diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 103 |
1 files changed, 64 insertions, 39 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 58dd394e1..6913d2b95 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -3,12 +3,11 @@ import sys import math import warnings +import numpy as np from .._utils import set_module import numpy.core.numeric as _nx -from numpy.core.numeric import ( - asarray, ScalarType, array, alltrue, cumprod, arange, ndim -) -from numpy.core.numerictypes import find_common_type, issubdtype +from numpy.core.numeric import ScalarType, array +from numpy.core.numerictypes import issubdtype import numpy.matrixlib as matrixlib from .function_base import diff @@ -94,7 +93,7 @@ def ix_(*args): nd = len(args) for k, new in enumerate(args): if not isinstance(new, _nx.ndarray): - new = asarray(new) + new = np.asarray(new) if new.size == 0: # Explicitly type empty arrays to avoid float default new = new.astype(_nx.intp) @@ -211,13 +210,13 @@ class nd_grid: class MGridClass(nd_grid): """ - `nd_grid` instance which returns a dense multi-dimensional "meshgrid". + An instance which returns a dense multi-dimensional "meshgrid". - An instance of `numpy.lib.index_tricks.nd_grid` which returns an dense - (or fleshed out) mesh-grid when indexed, so that each returned argument - has the same shape. The dimensions and number of the output arrays are - equal to the number of indexing dimensions. If the step length is not a - complex number, then the stop is not inclusive. + An instance which returns a dense (or fleshed out) mesh-grid + when indexed, so that each returned argument has the same shape. + The dimensions and number of the output arrays are equal to the + number of indexing dimensions. If the step length is not a complex + number, then the stop is not inclusive. However, if the step length is a **complex number** (e.g. 5j), then the integer part of its magnitude is interpreted as specifying the @@ -230,8 +229,7 @@ class MGridClass(nd_grid): See Also -------- - lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects - ogrid : like mgrid but returns open (not fleshed out) mesh grids + ogrid : like `mgrid` but returns open (not fleshed out) mesh grids meshgrid: return coordinate matrices from coordinate vectors r_ : array concatenator :ref:`how-to-partition` @@ -263,13 +261,13 @@ mgrid = MGridClass() class OGridClass(nd_grid): """ - `nd_grid` instance which returns an open multi-dimensional "meshgrid". + An instance which returns an open multi-dimensional "meshgrid". - An instance of `numpy.lib.index_tricks.nd_grid` which returns an open - (i.e. not fleshed out) mesh-grid when indexed, so that only one dimension - of each returned array is greater than 1. The dimension and number of the - output arrays are equal to the number of indexing dimensions. If the step - length is not a complex number, then the stop is not inclusive. + An instance which returns an open (i.e. not fleshed out) mesh-grid + when indexed, so that only one dimension of each returned array is + greater than 1. The dimension and number of the output arrays are + equal to the number of indexing dimensions. If the step length is + not a complex number, then the stop is not inclusive. However, if the step length is a **complex number** (e.g. 5j), then the integer part of its magnitude is interpreted as specifying the @@ -283,7 +281,6 @@ class OGridClass(nd_grid): See Also -------- - np.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects mgrid : like `ogrid` but returns dense (or fleshed out) mesh grids meshgrid: return coordinate matrices from coordinate vectors r_ : array concatenator @@ -343,9 +340,8 @@ class AxisConcatenator: axis = self.axis objs = [] - scalars = [] - arraytypes = [] - scalartypes = [] + # dtypes or scalars for weak scalar handling in result_type + result_type_objs = [] for k, item in enumerate(key): scalar = False @@ -391,12 +387,10 @@ class AxisConcatenator: except (ValueError, TypeError) as e: raise ValueError("unknown special directive") from e elif type(item) in ScalarType: - newobj = array(item, ndmin=ndmin) - scalars.append(len(objs)) scalar = True - scalartypes.append(newobj.dtype) + newobj = item else: - item_ndim = ndim(item) + item_ndim = np.ndim(item) newobj = array(item, copy=False, subok=True, ndmin=ndmin) if trans1d != -1 and item_ndim < ndmin: k2 = ndmin - item_ndim @@ -406,15 +400,20 @@ class AxisConcatenator: defaxes = list(range(ndmin)) axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2] newobj = newobj.transpose(axes) + objs.append(newobj) - if not scalar and isinstance(newobj, _nx.ndarray): - arraytypes.append(newobj.dtype) + if scalar: + result_type_objs.append(item) + else: + result_type_objs.append(newobj.dtype) - # Ensure that scalars won't up-cast unless warranted - final_dtype = find_common_type(arraytypes, scalartypes) - if final_dtype is not None: - for k in scalars: - objs[k] = objs[k].astype(final_dtype) + # Ensure that scalars won't up-cast unless warranted, for 0, drops + # through to error in concatenate. + if len(result_type_objs) != 0: + final_dtype = _nx.result_type(*result_type_objs) + # concatenate could do cast, but that can be overriden: + objs = [array(obj, copy=False, subok=True, + ndmin=ndmin, dtype=final_dtype) for obj in objs] res = self.concatenate(tuple(objs), axis=axis) @@ -596,7 +595,7 @@ class ndenumerate: """ def __init__(self, arr): - self.iter = asarray(arr).flat + self.iter = np.asarray(arr).flat def __next__(self): """ @@ -909,9 +908,9 @@ def fill_diagonal(a, val, wrap=False): else: # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. - if not alltrue(diff(a.shape) == 0): + if not np.all(diff(a.shape) == 0): raise ValueError("All dimensions of input must be of equal length") - step = 1 + (cumprod(a.shape[:-1])).sum() + step = 1 + (np.cumprod(a.shape[:-1])).sum() # Write the value out into the diagonal. a.flat[:end:step] = val @@ -982,7 +981,7 @@ def diag_indices(n, ndim=2): [0, 1]]]) """ - idx = arange(n) + idx = np.arange(n) return (idx,) * ndim @@ -1009,13 +1008,39 @@ def diag_indices_from(arr): ----- .. versionadded:: 1.4.0 + Examples + -------- + + Create a 4 by 4 array. + + >>> a = np.arange(16).reshape(4, 4) + >>> a + array([[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [12, 13, 14, 15]]) + + Get the indices of the diagonal elements. + + >>> di = np.diag_indices_from(a) + >>> di + (array([0, 1, 2, 3]), array([0, 1, 2, 3])) + + >>> a[di] + array([ 0, 5, 10, 15]) + + This is simply syntactic sugar for diag_indices. + + >>> np.diag_indices(a.shape[0]) + (array([0, 1, 2, 3]), array([0, 1, 2, 3])) + """ if not arr.ndim >= 2: raise ValueError("input array must be at least 2-d") # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. - if not alltrue(diff(arr.shape) == 0): + if not np.all(diff(arr.shape) == 0): raise ValueError("All dimensions of input must be of equal length") return diag_indices(arr.shape[0], arr.ndim) |