summaryrefslogtreecommitdiff
path: root/numpy/core/function_base.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-11-14 14:19:39 -0500
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-12-05 09:47:22 -0500
commit58ebb6a7d77cf89afeb888a70aff23e03d213788 (patch)
tree981d0fa4d4f80044a59bb0574241fbd25e89fa48 /numpy/core/function_base.py
parentbd1d6a5d51cda6fdac6986669962e6e79f425656 (diff)
downloadnumpy-58ebb6a7d77cf89afeb888a70aff23e03d213788.tar.gz
ENH: Allow {lin,log,geom}space start and stop to be arrays.
Diffstat (limited to 'numpy/core/function_base.py')
-rw-r--r--numpy/core/function_base.py73
1 files changed, 42 insertions, 31 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py
index 0fc56e70e..a7682620a 100644
--- a/numpy/core/function_base.py
+++ b/numpy/core/function_base.py
@@ -46,9 +46,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
Parameters
----------
- start : scalar
+ start : scalar or array_like
The starting value of the sequence.
- stop : scalar
+ stop : scalar or array_like
The end value of the sequence, unless `endpoint` is set to False.
In that case, the sequence consists of all but the last of ``num + 1``
evenly spaced samples, so that `stop` is excluded. Note that the step
@@ -72,7 +72,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
samples : ndarray
There are `num` equally spaced samples in the closed interval
``[start, stop]`` or the half-open interval ``[start, stop)``
- (depending on whether `endpoint` is True or False).
+ (depending on whether `endpoint` is True or False). If start
+ or stop are array-like, then the samples will be along a new
+ axis inserted at the beginning.
step : float, optional
Only returned if `retstep` is True
@@ -128,16 +130,15 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None):
if dtype is None:
dtype = dt
- y = _nx.arange(0, num, dtype=dt)
-
delta = stop - start
+ y = _nx.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)
# In-place multiplication y *= delta/div is faster, but prevents the multiplicant
# from overriding what class is produced, and thus prevents, e.g. use of Quantities,
# see gh-7142. Hence, we multiply in place only for standard scalar types.
- _mult_inplace = _nx.isscalar(delta)
+ _mult_inplace = _nx.isscalar(delta)
if num > 1:
step = delta / div
- if step == 0:
+ if _nx.any(step == 0):
# Special handling for denormal numbers, gh-5437
y /= div
if _mult_inplace:
@@ -182,9 +183,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
Parameters
----------
- start : float
+ start : float or array_like
``base ** start`` is the starting value of the sequence.
- stop : float
+ stop : float or array_like
``base ** stop`` is the final value of the sequence, unless `endpoint`
is False. In that case, ``num + 1`` values are spaced over the
interval in log-space, of which all but the last (a sequence of
@@ -205,7 +206,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
Returns
-------
samples : ndarray
- `num` samples, equally spaced on a log scale.
+ `num` samples, equally spaced on a log scale. If start or stop are
+ array-like, then the samples will be along a new axis inserted at
+ the beginning.
See Also
--------
@@ -270,9 +273,9 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None):
Parameters
----------
- start : scalar
+ start : scalar or array_like
The starting value of the sequence.
- stop : scalar
+ stop : scalar or array_like
The final value of the sequence, unless `endpoint` is False.
In that case, ``num + 1`` values are spaced over the
interval in log-space, of which all but the last (a sequence of
@@ -289,7 +292,9 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None):
Returns
-------
samples : ndarray
- `num` samples, equally spaced on a log scale.
+ `num` samples, equally spaced on a log scale. If start or stop are
+ array-like, then the samples will be along a new axis inserted at
+ the beginning.
See Also
--------
@@ -349,40 +354,46 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None):
>>> plt.show()
"""
- if start == 0 or stop == 0:
+ start = asanyarray(start)
+ stop = asanyarray(stop)
+ if _nx.any(start == 0) or _nx.any(stop == 0):
raise ValueError('Geometric sequence cannot include zero')
- dt = result_type(start, stop, float(num))
+ dt = result_type(start, stop, float(num), _nx.zeros((), dtype))
if dtype is None:
dtype = dt
else:
# complex to dtype('complex128'), for instance
dtype = _nx.dtype(dtype)
+ # Promote both arguments to the same dtype in case, for instance, one is
+ # complex and another is negative and log would produce NaN otherwise.
+ # Copy since we may change things in-place further down.
+ start = start.astype(dt, copy=True)
+ stop = stop.astype(dt, copy=True)
+
+ out_sign = _nx.ones(_nx.broadcast(start, stop).shape, dt)
# Avoid negligible real or imaginary parts in output by rotating to
# positive real, calculating, then undoing rotation
- out_sign = 1
- if start.real == stop.real == 0:
- start, stop = start.imag, stop.imag
- out_sign = 1j * out_sign
- if _nx.sign(start) == _nx.sign(stop) == -1:
- start, stop = -start, -stop
- out_sign = -out_sign
-
- # Promote both arguments to the same dtype in case, for instance, one is
- # complex and another is negative and log would produce NaN otherwise
- start = start + (stop - stop)
- stop = stop + (start - start)
- if _nx.issubdtype(dtype, _nx.complexfloating):
- start = start + 0j
- stop = stop + 0j
+ if _nx.issubdtype(dt, _nx.complexfloating):
+ all_imag = (start.real == 0.) & (stop.real == 0.)
+ if _nx.any(all_imag):
+ start[all_imag] = start[all_imag].imag
+ stop[all_imag] = stop[all_imag].imag
+ out_sign[all_imag] = 1j
+
+ both_negative = (_nx.sign(start) == -1) & (_nx.sign(stop) == -1)
+ if _nx.any(both_negative):
+ _nx.negative(start, out=start, where=both_negative)
+ _nx.negative(stop, out=stop, where=both_negative)
+ _nx.negative(out_sign, out=out_sign, where=both_negative)
log_start = _nx.log10(start)
log_stop = _nx.log10(stop)
result = out_sign * logspace(log_start, log_stop, num=num,
endpoint=endpoint, base=10.0, dtype=dtype)
- return result.astype(dtype)
+ return result.astype(dtype, copy=False)
#always succeed