summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorAndrew <andrewgsavage@gmail.com>2022-10-15 02:05:25 +0100
committerAndrew <andrewgsavage@gmail.com>2022-10-15 02:05:25 +0100
commit93958b932bb85be440582e66150964c05b2dd460 (patch)
tree02653d968f12fe6dfdb06fa6b6cd8334582340fb /pint
parent208d36a5bb051465296dad846296bb19628e2307 (diff)
downloadpint-93958b932bb85be440582e66150964c05b2dd460.tar.gz
broadcast_arrays
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/numpy/numpy_func.py9
-rw-r--r--pint/testsuite/test_numpy.py15
2 files changed, 22 insertions, 2 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 7143143..cb80914 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -877,7 +877,14 @@ for func_str in ["cumprod", "cumproduct", "nancumprod"]:
implement_single_dimensionless_argument_func(func_str)
# Handle single-argument consistent unit functions
-for func_str in ["block", "hstack", "vstack", "dstack", "column_stack"]:
+for func_str in [
+ "block",
+ "hstack",
+ "vstack",
+ "dstack",
+ "column_stack",
+ "broadcast_arrays",
+]:
implement_func(
"function", func_str, input_units="all_consistent", output_unit="match_input"
)
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index f2ddaf0..a6b47f4 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -82,7 +82,7 @@ class TestNumpyArrayManipulation(TestNumpyMethods):
# TODO
# https://www.numpy.org/devdocs/reference/routines.array-manipulation.html
# copyto
- # broadcast , broadcast_arrays
+ # broadcast
# asarray asanyarray asmatrix asfarray asfortranarray ascontiguousarray asarray_chkfinite asscalar require
# Changing array shape
@@ -271,6 +271,19 @@ class TestNumpyArrayManipulation(TestNumpyMethods):
def test_item(self):
helpers.assert_quantity_equal(self.Q_([[0]], "m").item(), 0 * self.ureg.m)
+ def test_broadcast_arrays(self):
+ x = self.Q_(np.array([[1, 2, 3]]), "m")
+ y = self.Q_(np.array([[4], [5]]), "nm")
+ result = np.broadcast_arrays(x, y)
+ expected = self.Q_(
+ [
+ [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
+ [[4e-09, 4e-09, 4e-09], [5e-09, 5e-09, 5e-09]],
+ ],
+ "m",
+ )
+ helpers.assert_quantity_equal(result, expected)
+
class TestNumpyMathematicalFunctions(TestNumpyMethods):
# https://www.numpy.org/devdocs/reference/routines.math.html