diff options
author | Aditya Panchal <apanchal@bastula.org> | 2016-01-29 18:44:59 -0600 |
---|---|---|
committer | Aditya Panchal <apanchal@bastula.org> | 2016-01-29 18:44:59 -0600 |
commit | 0e65b7166a6265a2047cb3ca47f487f3de19f0a6 (patch) | |
tree | e473af69f29f1bfb987be14ca206a3130ddc2b2b | |
parent | e2805398f9a63b825f4a2aab22e9f169ff65aae9 (diff) | |
download | numpy-0e65b7166a6265a2047cb3ca47f487f3de19f0a6.tar.gz |
BUG: Fixed regressions in np.piecewise in ref to #5737 and #5729.
Added unit tests for these conditions.
-rw-r--r-- | numpy/lib/function_base.py | 12 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 11 |
2 files changed, 19 insertions, 4 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index a1048002c..6eff945b0 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -944,11 +944,15 @@ def piecewise(x, condlist, funclist, *args, **kw): condlist = condlist.T if n == n2 - 1: # compute the "otherwise" condition. totlist = np.logical_or.reduce(condlist, axis=0) - condlist = np.vstack([condlist, ~totlist]) + try: + condlist = np.vstack([condlist, ~totlist]) + except: + condlist = [asarray(c, dtype=bool) for c in condlist] + totlist = condlist[0] + for k in range(1, n): + totlist |= condlist[k] + condlist.append(~totlist) n += 1 - if (n != n2): - raise ValueError( - "function list and condition list must be the same") y = zeros(x.shape, x.dtype) for k in range(n): diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index d6a838f3a..878d00bdf 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -1862,6 +1862,10 @@ class TestPiecewise(TestCase): x = piecewise([1, 2], [[True, False], [False, True]], [3, 4]) assert_array_equal(x, [3, 4]) + def test_scalar_domains_three_conditions(self): + x = piecewise(3, [True, False, False], [4, 2, 0]) + assert_equal(x, 4) + def test_default(self): # No value specified for x[1], should be 0 x = piecewise([1, 2], [True, False], [2]) @@ -1886,6 +1890,13 @@ class TestPiecewise(TestCase): x = 3 piecewise(x, [x <= 3, x > 3], [4, 0]) # Should succeed. + def test_multidimensional_extrafunc(self): + x = np.array([[-2.5, -1.5, -0.5], + [0.5, 1.5, 2.5]]) + y = piecewise(x, [x < 0, x >= 2], [-1, 1, 3]) + assert_array_equal(y, np.array([[-1., -1., -1.], + [3., 3., 1.]])) + class TestBincount(TestCase): |