summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex <aprockhill206@gmail.com>2020-05-14 18:47:59 -0700
committerAlex <aprockhill206@gmail.com>2020-07-22 11:59:30 -0700
commit8066b45451eff24228bb5af96aad2fe0bd548383 (patch)
tree83864ac983af6b18ef43ca8c728d875a2b1c345b
parenta39e3021b9304fb5a76542d444b7fec2dcff1374 (diff)
downloadnumpy-8066b45451eff24228bb5af96aad2fe0bd548383.tar.gz
edge first try
ENH: added edge keyword argument to digitize added test
-rw-r--r--numpy/lib/function_base.py20
-rw-r--r--numpy/lib/tests/test_function_base.py4
2 files changed, 22 insertions, 2 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 6ea9cc4de..c7f3dc033 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -4733,12 +4733,12 @@ def append(arr, values, axis=None):
return concatenate((arr, values), axis=axis)
-def _digitize_dispatcher(x, bins, right=None):
+def _digitize_dispatcher(x, bins, right=None, edge=None):
return (x, bins)
@array_function_dispatch(_digitize_dispatcher)
-def digitize(x, bins, right=False):
+def digitize(x, bins, right=False, edge=False):
"""
Return the indices of the bins to which each value in input array belongs.
@@ -4767,6 +4767,10 @@ def digitize(x, bins, right=False):
does not include the right edge. The left bin end is open in this
case, i.e., bins[i-1] <= x < bins[i] is the default behavior for
monotonically increasing bins.
+ edge : bool, optional
+ Whether to include the last right edge if right == False or the first
+ left edge if right == True. If egde==True, the entire interval is
+ included that would otherwise not be for the first or last edge case.
Returns
-------
@@ -4839,6 +4843,18 @@ def digitize(x, bins, right=False):
if mono == 0:
raise ValueError("bins must be monotonically increasing or decreasing")
+ if edge:
+ # if cannot make round trip, cannot use eps
+ if np.issubdtype(bins.dtype, _nx.int64):
+ if (bins != bins.astype(_nx.float64).astype(_nx.int64)).any():
+ raise ValueError("bins have too large values to use"
+ "'edges=True'")
+ bins = bins.astype(_nx.float64)
+ if right:
+ bins[0] -= np.finfo(_nx.float64).eps * 2 * mono
+ else:
+ bins[-1] += np.finfo(_nx.float64).eps * 2 * mono
+
# this is backwards because the arguments below are swapped
side = 'left' if right else 'right'
if mono == -1:
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index eb2fc3311..35225ff21 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1712,6 +1712,9 @@ class TestDigitize:
bins = [1, 1, 0]
assert_array_equal(digitize(x, bins, False), [3, 2, 0, 0])
assert_array_equal(digitize(x, bins, True), [3, 3, 2, 0])
+ bins = [-1, 0, 1, 2]
+ assert_array_equal(digitize(x, bins, False, True), [1, 2, 3, 3])
+ assert_array_equal(digitize(x, bins, True, True), [1, 1, 2, 3])
bins = [1, 1, 1, 1]
assert_array_equal(digitize(x, bins, False), [0, 0, 4, 4])
assert_array_equal(digitize(x, bins, True), [0, 0, 0, 4])
@@ -1740,6 +1743,7 @@ class TestDigitize:
# gh-11022
x = 2**54 # loses precision in a float
assert_equal(np.digitize(x, [x - 1, x + 1]), 1)
+ assert_raises(ValueError, digitize, x, [x - 1, x + 1], False, True)
@pytest.mark.xfail(
reason="gh-11022: np.core.multiarray._monoticity loses precision")