diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-11-14 14:19:39 -0500 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2018-12-05 09:47:22 -0500 |
commit | 58ebb6a7d77cf89afeb888a70aff23e03d213788 (patch) | |
tree | 981d0fa4d4f80044a59bb0574241fbd25e89fa48 /numpy/core/function_base.py | |
parent | bd1d6a5d51cda6fdac6986669962e6e79f425656 (diff) | |
download | numpy-58ebb6a7d77cf89afeb888a70aff23e03d213788.tar.gz |
ENH: Allow {lin,log,geom}space start and stop to be arrays.
Diffstat (limited to 'numpy/core/function_base.py')
-rw-r--r-- | numpy/core/function_base.py | 73 |
1 files changed, 42 insertions, 31 deletions
diff --git a/numpy/core/function_base.py b/numpy/core/function_base.py index 0fc56e70e..a7682620a 100644 --- a/numpy/core/function_base.py +++ b/numpy/core/function_base.py @@ -46,9 +46,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): Parameters ---------- - start : scalar + start : scalar or array_like The starting value of the sequence. - stop : scalar + stop : scalar or array_like The end value of the sequence, unless `endpoint` is set to False. In that case, the sequence consists of all but the last of ``num + 1`` evenly spaced samples, so that `stop` is excluded. Note that the step @@ -72,7 +72,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): samples : ndarray There are `num` equally spaced samples in the closed interval ``[start, stop]`` or the half-open interval ``[start, stop)`` - (depending on whether `endpoint` is True or False). + (depending on whether `endpoint` is True or False). If start + or stop are array-like, then the samples will be along a new + axis inserted at the beginning. step : float, optional Only returned if `retstep` is True @@ -128,16 +130,15 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): if dtype is None: dtype = dt - y = _nx.arange(0, num, dtype=dt) - delta = stop - start + y = _nx.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim) # In-place multiplication y *= delta/div is faster, but prevents the multiplicant # from overriding what class is produced, and thus prevents, e.g. use of Quantities, # see gh-7142. Hence, we multiply in place only for standard scalar types. - _mult_inplace = _nx.isscalar(delta) + _mult_inplace = _nx.isscalar(delta) if num > 1: step = delta / div - if step == 0: + if _nx.any(step == 0): # Special handling for denormal numbers, gh-5437 y /= div if _mult_inplace: @@ -182,9 +183,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): Parameters ---------- - start : float + start : float or array_like ``base ** start`` is the starting value of the sequence. - stop : float + stop : float or array_like ``base ** stop`` is the final value of the sequence, unless `endpoint` is False. In that case, ``num + 1`` values are spaced over the interval in log-space, of which all but the last (a sequence of @@ -205,7 +206,9 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): Returns ------- samples : ndarray - `num` samples, equally spaced on a log scale. + `num` samples, equally spaced on a log scale. If start or stop are + array-like, then the samples will be along a new axis inserted at + the beginning. See Also -------- @@ -270,9 +273,9 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None): Parameters ---------- - start : scalar + start : scalar or array_like The starting value of the sequence. - stop : scalar + stop : scalar or array_like The final value of the sequence, unless `endpoint` is False. In that case, ``num + 1`` values are spaced over the interval in log-space, of which all but the last (a sequence of @@ -289,7 +292,9 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None): Returns ------- samples : ndarray - `num` samples, equally spaced on a log scale. + `num` samples, equally spaced on a log scale. If start or stop are + array-like, then the samples will be along a new axis inserted at + the beginning. See Also -------- @@ -349,40 +354,46 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None): >>> plt.show() """ - if start == 0 or stop == 0: + start = asanyarray(start) + stop = asanyarray(stop) + if _nx.any(start == 0) or _nx.any(stop == 0): raise ValueError('Geometric sequence cannot include zero') - dt = result_type(start, stop, float(num)) + dt = result_type(start, stop, float(num), _nx.zeros((), dtype)) if dtype is None: dtype = dt else: # complex to dtype('complex128'), for instance dtype = _nx.dtype(dtype) + # Promote both arguments to the same dtype in case, for instance, one is + # complex and another is negative and log would produce NaN otherwise. + # Copy since we may change things in-place further down. + start = start.astype(dt, copy=True) + stop = stop.astype(dt, copy=True) + + out_sign = _nx.ones(_nx.broadcast(start, stop).shape, dt) # Avoid negligible real or imaginary parts in output by rotating to # positive real, calculating, then undoing rotation - out_sign = 1 - if start.real == stop.real == 0: - start, stop = start.imag, stop.imag - out_sign = 1j * out_sign - if _nx.sign(start) == _nx.sign(stop) == -1: - start, stop = -start, -stop - out_sign = -out_sign - - # Promote both arguments to the same dtype in case, for instance, one is - # complex and another is negative and log would produce NaN otherwise - start = start + (stop - stop) - stop = stop + (start - start) - if _nx.issubdtype(dtype, _nx.complexfloating): - start = start + 0j - stop = stop + 0j + if _nx.issubdtype(dt, _nx.complexfloating): + all_imag = (start.real == 0.) & (stop.real == 0.) + if _nx.any(all_imag): + start[all_imag] = start[all_imag].imag + stop[all_imag] = stop[all_imag].imag + out_sign[all_imag] = 1j + + both_negative = (_nx.sign(start) == -1) & (_nx.sign(stop) == -1) + if _nx.any(both_negative): + _nx.negative(start, out=start, where=both_negative) + _nx.negative(stop, out=stop, where=both_negative) + _nx.negative(out_sign, out=out_sign, where=both_negative) log_start = _nx.log10(start) log_stop = _nx.log10(stop) result = out_sign * logspace(log_start, log_stop, num=num, endpoint=endpoint, base=10.0, dtype=dtype) - return result.astype(dtype) + return result.astype(dtype, copy=False) #always succeed |