summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkeewis <keewis@users.noreply.github.com>2022-02-02 22:49:31 +0100
committerGitHub <noreply@github.com>2022-02-02 22:49:31 +0100
commitf1dc122abc48eae80d88bc3adeb99081e46e6f29 (patch)
treeff2b2ff98126cea2822f628c5a3645d33ccc78ac
parent8f5f7e56fc5a8e990c85cb01d6636c165cf41a15 (diff)
downloadpint-f1dc122abc48eae80d88bc3adeb99081e46e6f29.tar.gz
implement `numpy.nanprod` (#1369)
* add tests for nanprod * fix the exponent for nanprod * use `helpers.assert_quantity_equal` instead of `self.assertQuantityEqual` * add a entry to `CHANGES`
-rw-r--r--CHANGES1
-rw-r--r--pint/numpy_func.py65
-rw-r--r--pint/testsuite/test_numpy.py10
3 files changed, 51 insertions, 25 deletions
diff --git a/CHANGES b/CHANGES
index f34a555..aa5443d 100644
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,7 @@ Pint Changelog
- Fix casting error when using to_reduced_units with array of int.
(Issue #1184)
- Use default numpy `np.printoptions` available since numpy 1.15.
+- Implement `numpy.nanprod` (Issue #1369)
- Fix default_format ignored for measurement (Issue #1456)
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index 38aab1a..5c48e5a 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -679,34 +679,49 @@ def _all(a, *args, **kwargs):
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")
-@implements("prod", "function")
-def _prod(a, *args, **kwargs):
- arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where")
- all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs)
- axis = all_kwargs.get("axis", None)
- where = all_kwargs.get("where", None)
-
- registry = a.units._REGISTRY
-
- if axis is not None and where is not None:
- _, where_ = np.broadcast_arrays(a._magnitude, where)
- exponents = np.unique(np.sum(where_, axis=axis))
- if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents):
- units = a.units ** np.max(exponents)
+def implement_prod_func(name):
+ if np is None:
+ return
+
+ func = getattr(np, name, None)
+ if func is None:
+ return
+
+ @implements(name, "function")
+ def _prod(a, *args, **kwargs):
+ arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where")
+ all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs)
+ axis = all_kwargs.get("axis", None)
+ where = all_kwargs.get("where", None)
+
+ registry = a.units._REGISTRY
+
+ if axis is not None and where is not None:
+ _, where_ = np.broadcast_arrays(a._magnitude, where)
+ exponents = np.unique(np.sum(where_, axis=axis))
+ if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents):
+ units = a.units ** np.max(exponents)
+ else:
+ units = registry.dimensionless
+ a = a.to(units)
+ elif axis is not None:
+ units = a.units ** a.shape[axis]
+ elif where is not None:
+ exponent = np.sum(where)
+ units = a.units ** exponent
else:
- units = registry.dimensionless
- a = a.to(units)
- elif axis is not None:
- units = a.units ** a.shape[axis]
- elif where is not None:
- exponent = np.sum(where)
- units = a.units ** exponent
- else:
- units = a.units ** a.size
+ exponent = (
+ np.sum(np.logical_not(np.isnan(a))) if name == "nanprod" else a.size
+ )
+ units = a.units ** exponent
+
+ result = func(a._magnitude, *args, **kwargs)
+
+ return registry.Quantity(result, units)
- result = np.prod(a._magnitude, *args, **kwargs)
- return registry.Quantity(result, units)
+for name in ["prod", "nanprod"]:
+ implement_prod_func(name)
# Implement simple matching-unit or stripped-unit functions based on signature
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index ce337b2..5e9915b 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -329,6 +329,16 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods):
np.prod(self.q, axis=axis, where=[True, False]), [3, 1] * self.ureg.m ** 2
)
+ @helpers.requires_array_function_protocol()
+ def test_nanprod_numpy_func(self):
+ helpers.assert_quantity_equal(np.nanprod(self.q_nan), 6 * self.ureg.m ** 3)
+ helpers.assert_quantity_equal(
+ np.nanprod(self.q_nan, axis=0), [3, 2] * self.ureg.m ** 2
+ )
+ helpers.assert_quantity_equal(
+ np.nanprod(self.q_nan, axis=1), [2, 3] * self.ureg.m ** 2
+ )
+
def test_sum(self):
assert self.q.sum() == 10 * self.ureg.m
helpers.assert_quantity_equal(self.q.sum(0), [4, 6] * self.ureg.m)