diff options
author | Alex <aprockhill206@gmail.com> | 2020-05-14 18:47:59 -0700 |
---|---|---|
committer | Alex <aprockhill206@gmail.com> | 2020-07-22 11:59:30 -0700 |
commit | 8066b45451eff24228bb5af96aad2fe0bd548383 (patch) | |
tree | 83864ac983af6b18ef43ca8c728d875a2b1c345b | |
parent | a39e3021b9304fb5a76542d444b7fec2dcff1374 (diff) | |
download | numpy-8066b45451eff24228bb5af96aad2fe0bd548383.tar.gz |
edge first try
ENH: added edge keyword argument to digitize
added test
-rw-r--r-- | numpy/lib/function_base.py | 20 |
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 4 |
2 files changed, 22 insertions, 2 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 6ea9cc4de..c7f3dc033 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -4733,12 +4733,12 @@ def append(arr, values, axis=None): return concatenate((arr, values), axis=axis) -def _digitize_dispatcher(x, bins, right=None): +def _digitize_dispatcher(x, bins, right=None, edge=None): return (x, bins) @array_function_dispatch(_digitize_dispatcher) -def digitize(x, bins, right=False): +def digitize(x, bins, right=False, edge=False): """ Return the indices of the bins to which each value in input array belongs. @@ -4767,6 +4767,10 @@ def digitize(x, bins, right=False): does not include the right edge. The left bin end is open in this case, i.e., bins[i-1] <= x < bins[i] is the default behavior for monotonically increasing bins. + edge : bool, optional + Whether to include the last right edge if right == False or the first + left edge if right == True. If egde==True, the entire interval is + included that would otherwise not be for the first or last edge case. 
Returns ------- @@ -4839,6 +4843,18 @@ def digitize(x, bins, right=False): if mono == 0: raise ValueError("bins must be monotonically increasing or decreasing") + if edge: + # if cannot make round trip, cannot use eps + if np.issubdtype(bins.dtype, _nx.int64): + if (bins != bins.astype(_nx.float64).astype(_nx.int64)).any(): + raise ValueError("bins have too large values to use" + "'edges=True'") + bins = bins.astype(_nx.float64) + if right: + bins[0] -= np.finfo(_nx.float64).eps * 2 * mono + else: + bins[-1] += np.finfo(_nx.float64).eps * 2 * mono + # this is backwards because the arguments below are swapped side = 'left' if right else 'right' if mono == -1: diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index eb2fc3311..35225ff21 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1712,6 +1712,9 @@ class TestDigitize: bins = [1, 1, 0] assert_array_equal(digitize(x, bins, False), [3, 2, 0, 0]) assert_array_equal(digitize(x, bins, True), [3, 3, 2, 0]) + bins = [-1, 0, 1, 2] + assert_array_equal(digitize(x, bins, False, True), [1, 2, 3, 3]) + assert_array_equal(digitize(x, bins, True, True), [1, 1, 2, 3]) bins = [1, 1, 1, 1] assert_array_equal(digitize(x, bins, False), [0, 0, 4, 4]) assert_array_equal(digitize(x, bins, True), [0, 0, 0, 4]) @@ -1740,6 +1743,7 @@ class TestDigitize: # gh-11022 x = 2**54 # loses precision in a float assert_equal(np.digitize(x, [x - 1, x + 1]), 1) + assert_raises(ValueError, digitize, x, [x - 1, x + 1], False, True) @pytest.mark.xfail( reason="gh-11022: np.core.multiarray._monoticity loses precision") |