From b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Thu, 23 May 2019 06:41:18 -0700 Subject: ENH: Always produce a consistent shape in the result of `argwhere` Previously this would return 1d indices even though the array is zero-d. Note that using atleast1d inside numeric required an import change to avoid a circular import. --- numpy/core/shape_base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'numpy/core/shape_base.py') diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 710f64827..d7e769e62 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -9,8 +9,9 @@ import warnings from . import numeric as _nx from . import overrides -from .numeric import array, asanyarray, newaxis +from ._asarray import array, asanyarray from .multiarray import normalize_axis_index +from . import fromnumeric as _from_nx array_function_dispatch = functools.partial( @@ -123,7 +124,7 @@ def atleast_2d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1) elif ary.ndim == 1: - result = ary[newaxis, :] + result = ary[_nx.newaxis, :] else: result = ary res.append(result) @@ -193,9 +194,9 @@ def atleast_3d(*arys): if ary.ndim == 0: result = ary.reshape(1, 1, 1) elif ary.ndim == 1: - result = ary[newaxis, :, newaxis] + result = ary[_nx.newaxis, :, _nx.newaxis] elif ary.ndim == 2: - result = ary[:, :, newaxis] + result = ary[:, :, _nx.newaxis] else: result = ary res.append(result) @@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None): # Internal functions to eliminate the overhead of repeated dispatch in one of # the two possible paths inside np.block. # Use getattr to protect against __array_function__ being disabled. -_size = getattr(_nx.size, '__wrapped__', _nx.size) -_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim) -_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate) +_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size) +_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim) +_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate) def _block_format_index(index): -- cgit v1.2.1