diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2019-05-23 06:41:18 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2019-09-05 22:07:06 -0700 |
commit | b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch) | |
tree | e10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/shape_base.py | |
parent | e3f4c536a0014789dbd0321926b4f62c39d73719 (diff) | |
download | numpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz |
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d.
Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 710f64827..d7e769e62 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -9,8 +9,9 @@ import warnings from . import numeric as _nx from . import overrides -from .numeric import array, asanyarray, newaxis +from ._asarray import array, asanyarray from .multiarray import normalize_axis_index +from . import fromnumeric as _from_nx array_function_dispatch = functools.partial( @@ -123,7 +124,7 @@ def atleast_2d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1) elif ary.ndim == 1: - result = ary[newaxis, :] + result = ary[_nx.newaxis, :] else: result = ary res.append(result) @@ -193,9 +194,9 @@ def atleast_3d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1, 1) elif ary.ndim == 1: - result = ary[newaxis, :, newaxis] + result = ary[_nx.newaxis, :, _nx.newaxis] elif ary.ndim == 2: - result = ary[:, :, newaxis] + result = ary[:, :, _nx.newaxis] else: result = ary res.append(result) @@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None): # Internal functions to eliminate the overhead of repeated dispatch in one of # the two possible paths inside np.block. # Use getattr to protect against __array_function__ being disabled. -_size = getattr(_nx.size, '__wrapped__', _nx.size) -_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim) -_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate) +_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size) +_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim) +_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate) def _block_format_index(index): |