summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex <aprockhill206@gmail.com>2020-05-15 09:56:38 -0700
committerAlex <aprockhill206@gmail.com>2020-07-22 11:59:31 -0700
commit5840165bd7db8628d0d5b318544943a28d799068 (patch)
treeedb285c83873da369951dc33a5f38f1f02a82054
parent8c4ce9936ec6989fdf2b2374489f023681e329a3 (diff)
downloadnumpy-5840165bd7db8628d0d5b318544943a28d799068.tar.gz
simplified
-rw-r--r--numpy/lib/function_base.py25
-rw-r--r--numpy/lib/tests/test_function_base.py3
2 files changed, 14 insertions, 14 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index f629a8fdb..3199c6169 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -4768,9 +4768,9 @@ def digitize(x, bins, right=False, edge=False):
case, i.e., bins[i-1] <= x < bins[i] is the default behavior for
monotonically increasing bins.
edge : bool, optional
- Whether to include the last right edge if right == False or the first
- left edge if right == True. If egde==True, the entire interval is
- included that would otherwise not be for the first or last edge case.
+ Whether to include the last right edge if right==False or the first
+ left edge if right==True so that the whole interval from the least
+ to the greatest value of bins is covered.
Returns
-------
@@ -4786,7 +4786,7 @@ def digitize(x, bins, right=False, edge=False):
See Also
--------
- bincount, histogram, unique, searchsorted
+ bincount, histogram, unique, nextafter, searchsorted
Notes
-----
@@ -4844,18 +4844,15 @@ def digitize(x, bins, right=False, edge=False):
raise ValueError("bins must be monotonically increasing or decreasing")
if edge:
- # if cannot make round trip, cannot use eps
+ # move first bin eps if right edge not included else move last bin
+ idx = 0 if right else -1
+ # move bin down if going up and using right or going down and using
+ # left else move bin up
+ delta = -mono if right else mono
if np.issubdtype(bins.dtype, _nx.integer):
- if right:
- bins[0] -= mono
- else:
- bins[-1] += mono
+ bins[idx] += delta
else:
- bins = bins.astype(_nx.float64)
- if right:
- bins[0] -= np.finfo(_nx.float64).eps * 2 * mono
- else:
- bins[-1] += np.finfo(_nx.float64).eps * 2 * mono
+ bins[idx] = np.nextafter(bins[idx], bins[idx] + delta)
# this is backwards because the arguments below are swapped
side = 'left' if right else 'right'
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index d6e768a2b..32f660772 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1715,6 +1715,9 @@ class TestDigitize:
bins = [-1, 0, 1, 2]
assert_array_equal(digitize(x, bins, False, True), [1, 2, 3, 3])
assert_array_equal(digitize(x, bins, True, True), [1, 1, 2, 3])
+ bins = [2, 1, 0, -1]
+ assert_array_equal(digitize(x, bins, False, True), [3, 2, 1, 1])
+ assert_array_equal(digitize(x, bins, True, True), [3, 3, 2, 1])
bins = [1, 1, 1, 1]
assert_array_equal(digitize(x, bins, False), [0, 0, 4, 4])
assert_array_equal(digitize(x, bins, True), [0, 0, 0, 4])