summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Zwicker <david.zwicker@ds.mpg.de>2019-11-21 18:18:19 +0100
committerWarren Weckesser <warren.weckesser@gmail.com>2019-11-21 12:18:19 -0500
commit8efd1e70a7789befc7a85a0b52254c68f7ae408a (patch)
treeb9f62e89a0aa3d462f2fa2d0a41a2000b59488ef
parentd428127183d46b2fbd99afefa4670642addf8d6e (diff)
downloadnumpy-8efd1e70a7789befc7a85a0b52254c68f7ae408a.tar.gz
BUG: Fix step returned by linspace when num=1 and endpoint=False (#14929)
Changed the the behavior of linspace to return a proper step size for arguments num=1 and endpoint=False, where previously NaN was returned. closes gh-14927
-rw-r--r--numpy/core/function_base.py5
-rw-r--r--numpy/core/tests/test_function_base.py20
2 files changed, 16 insertions, 9 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py
index 42604ec3f..538ac8b84 100644
--- a/numpy/core/function_base.py
+++ b/numpy/core/function_base.py
@@ -139,7 +139,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
# from overriding what class is produced, and thus prevents, e.g. use of Quantities,
# see gh-7142. Hence, we multiply in place only for standard scalar types.
_mult_inplace = _nx.isscalar(delta)
- if num > 1:
+ if div > 0:
step = delta / div
if _nx.any(step == 0):
# Special handling for denormal numbers, gh-5437
@@ -154,7 +154,8 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
else:
y = y * step
else:
- # 0 and 1 item long sequences have an undefined step
+ # sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
+ # have an undefined step
step = NaN
# Multiply with delta to allow possible override of output class.
y = y * delta
diff --git a/numpy/core/tests/test_function_base.py b/numpy/core/tests/test_function_base.py
index 84b60b19c..c8a7cb6ce 100644
--- a/numpy/core/tests/test_function_base.py
+++ b/numpy/core/tests/test_function_base.py
@@ -351,14 +351,20 @@ class TestLinspace(object):
arange(j+1, dtype=int))
def test_retstep(self):
- y = linspace(0, 1, 2, retstep=True)
- assert_(isinstance(y, tuple) and len(y) == 2)
- for num in (0, 1):
- for ept in (False, True):
+ for num in [0, 1, 2]:
+ for ept in [False, True]:
y = linspace(0, 1, num, endpoint=ept, retstep=True)
- assert_(isinstance(y, tuple) and len(y) == 2 and
- len(y[0]) == num and isnan(y[1]),
- 'num={0}, endpoint={1}'.format(num, ept))
+ assert isinstance(y, tuple) and len(y) == 2
+ if num == 2:
+ y0_expect = [0.0, 1.0] if ept else [0.0, 0.5]
+ assert_array_equal(y[0], y0_expect)
+ assert_equal(y[1], y0_expect[1])
+ elif num == 1 and not ept:
+ assert_array_equal(y[0], [0.0])
+ assert_equal(y[1], 1.0)
+ else:
+ assert_array_equal(y[0], [0.0][:num])
+ assert isnan(y[1])
def test_object(self):
start = array(1, dtype='O')