summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAditya Panchal <apanchal@bastula.org>2016-01-29 18:44:59 -0600
committerAditya Panchal <apanchal@bastula.org>2016-01-29 18:44:59 -0600
commit0e65b7166a6265a2047cb3ca47f487f3de19f0a6 (patch)
treee473af69f29f1bfb987be14ca206a3130ddc2b2b
parente2805398f9a63b825f4a2aab22e9f169ff65aae9 (diff)
downloadnumpy-0e65b7166a6265a2047cb3ca47f487f3de19f0a6.tar.gz
BUG: Fixed regressions in np.piecewise in ref to #5737 and #5729.
Added unit tests for these conditions.
-rw-r--r--numpy/lib/function_base.py12
-rw-r--r--numpy/lib/tests/test_function_base.py11
2 files changed, 19 insertions, 4 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index a1048002c..6eff945b0 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -944,11 +944,15 @@ def piecewise(x, condlist, funclist, *args, **kw):
condlist = condlist.T
if n == n2 - 1: # compute the "otherwise" condition.
totlist = np.logical_or.reduce(condlist, axis=0)
- condlist = np.vstack([condlist, ~totlist])
+ try:
+ condlist = np.vstack([condlist, ~totlist])
+ except:
+ condlist = [asarray(c, dtype=bool) for c in condlist]
+ totlist = condlist[0]
+ for k in range(1, n):
+ totlist |= condlist[k]
+ condlist.append(~totlist)
n += 1
- if (n != n2):
- raise ValueError(
- "function list and condition list must be the same")
y = zeros(x.shape, x.dtype)
for k in range(n):
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index d6a838f3a..878d00bdf 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -1862,6 +1862,10 @@ class TestPiecewise(TestCase):
x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
assert_array_equal(x, [3, 4])
+ def test_scalar_domains_three_conditions(self):
+ x = piecewise(3, [True, False, False], [4, 2, 0])
+ assert_equal(x, 4)
+
def test_default(self):
# No value specified for x[1], should be 0
x = piecewise([1, 2], [True, False], [2])
@@ -1886,6 +1890,13 @@ class TestPiecewise(TestCase):
x = 3
piecewise(x, [x <= 3, x > 3], [4, 0]) # Should succeed.
+ def test_multidimensional_extrafunc(self):
+ x = np.array([[-2.5, -1.5, -0.5],
+ [0.5, 1.5, 2.5]])
+ y = piecewise(x, [x < 0, x >= 2], [-1, 1, 3])
+ assert_array_equal(y, np.array([[-1., -1., -1.],
+ [3., 3., 1.]]))
+
class TestBincount(TestCase):