summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDCtheTall <dylancutler@google.com>2021-03-31 17:11:41 -0400
committerDCtheTall <dylancutler@google.com>2021-03-31 17:11:41 -0400
commit3cdb33f02298c3544f7a3bc312c42422a2a7b971 (patch)
treeee797ab4abee784fbbb6668d3a1e84783fe8b535
parentc1ce397565398dcf22eda23f37d3c77ffe8f9b13 (diff)
downloadnumpy-3cdb33f02298c3544f7a3bc312c42422a2a7b971.tar.gz
Add tests np.meshgrid for higher dimensional grids.
-rw-r--r--numpy/lib/tests/test_function_base.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index afcb81eff..4201afac3 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2219,6 +2219,7 @@ class TestMsort:
[0.64864341, 0.79115165, 0.96098397]]))
+# Run test using: python3 runtests.py -t numpy.lib.tests.test_function_base
class TestMeshgrid:
def test_simple(self):
@@ -2307,6 +2308,20 @@ class TestMeshgrid:
assert_equal(x[0, :], 0)
assert_equal(x[1, :], X)
+ def test_higher_dimensions(self):
+ a, b, c = np.meshgrid([0], [1, 1], [2, 2])
+ assert_equal(a, [[[0, 0]], [[0, 0]]])
+ assert_equal(b, [[[1, 1]], [[1, 1]]])
+ assert_equal(c, [[[2, 2]], [[2, 2]]])
+
+ a, b, c, d, e = np.meshgrid(*([0] * i for i in range(1, 6)))
+ expected_shape = (2, 1, 3, 4, 5)
+ assert_equal(a.shape, expected_shape)
+ assert_equal(b.shape, expected_shape)
+ assert_equal(c.shape, expected_shape)
+ assert_equal(d.shape, expected_shape)
+ assert_equal(e.shape, expected_shape)
+
class TestPiecewise: