summaryrefslogtreecommitdiff
path: root/numpy/core/shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2019-05-23 06:41:18 -0700
committerEric Wieser <wieser.eric@gmail.com>2019-09-05 22:07:06 -0700
commitb6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2 (patch)
treee10b5be4cbd0c9064060a183572ad22c6a6ee044 /numpy/core/shape_base.py
parente3f4c536a0014789dbd0321926b4f62c39d73719 (diff)
downloadnumpy-b6a3ee3b7a961cfc7bcf8740c2bc89153c07f6b2.tar.gz
ENH: Always produce a consistent shape in the result of `argwhere`
Previously this would return 1d indices even though the array is zero-d. Note that using atleast1d inside numeric required an import change to avoid a circular import.
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r--numpy/core/shape_base.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index 710f64827..d7e769e62 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -9,8 +9,9 @@ import warnings
from . import numeric as _nx
from . import overrides
-from .numeric import array, asanyarray, newaxis
+from ._asarray import array, asanyarray
from .multiarray import normalize_axis_index
+from . import fromnumeric as _from_nx
array_function_dispatch = functools.partial(
@@ -123,7 +124,7 @@ def atleast_2d(*arys):
if ary.ndim == 0:
result = ary.reshape(1, 1)
elif ary.ndim == 1:
- result = ary[newaxis, :]
+ result = ary[_nx.newaxis, :]
else:
result = ary
res.append(result)
@@ -193,9 +194,9 @@ def atleast_3d(*arys):
if ary.ndim == 0:
result = ary.reshape(1, 1, 1)
elif ary.ndim == 1:
- result = ary[newaxis, :, newaxis]
+ result = ary[_nx.newaxis, :, _nx.newaxis]
elif ary.ndim == 2:
- result = ary[:, :, newaxis]
+ result = ary[:, :, _nx.newaxis]
else:
result = ary
res.append(result)
@@ -435,9 +436,9 @@ def stack(arrays, axis=0, out=None):
# Internal functions to eliminate the overhead of repeated dispatch in one of
# the two possible paths inside np.block.
# Use getattr to protect against __array_function__ being disabled.
-_size = getattr(_nx.size, '__wrapped__', _nx.size)
-_ndim = getattr(_nx.ndim, '__wrapped__', _nx.ndim)
-_concatenate = getattr(_nx.concatenate, '__wrapped__', _nx.concatenate)
+_size = getattr(_from_nx.size, '__wrapped__', _from_nx.size)
+_ndim = getattr(_from_nx.ndim, '__wrapped__', _from_nx.ndim)
+_concatenate = getattr(_from_nx.concatenate, '__wrapped__', _from_nx.concatenate)
def _block_format_index(index):