diff options
author | keewis <keewis@users.noreply.github.com> | 2022-02-02 22:49:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-02 22:49:31 +0100 |
commit | f1dc122abc48eae80d88bc3adeb99081e46e6f29 (patch) | |
tree | ff2b2ff98126cea2822f628c5a3645d33ccc78ac | |
parent | 8f5f7e56fc5a8e990c85cb01d6636c165cf41a15 (diff) | |
download | pint-f1dc122abc48eae80d88bc3adeb99081e46e6f29.tar.gz |
implement `numpy.nanprod` (#1369)
* add tests for nanprod
* fix the exponent for nanprod
* use `helpers.assert_quantity_equal` instead of `self.assertQuantityEqual`
* add a entry to `CHANGES`
-rw-r--r-- | CHANGES | 1 | ||||
-rw-r--r-- | pint/numpy_func.py | 65 | ||||
-rw-r--r-- | pint/testsuite/test_numpy.py | 10 |
3 files changed, 51 insertions, 25 deletions
@@ -12,6 +12,7 @@ Pint Changelog - Fix casting error when using to_reduced_units with array of int. (Issue #1184) - Use default numpy `np.printoptions` available since numpy 1.15. +- Implement `numpy.nanprod` (Issue #1369) - Fix default_format ignored for measurement (Issue #1456) diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 38aab1a..5c48e5a 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -679,34 +679,49 @@ def _all(a, *args, **kwargs): raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") -@implements("prod", "function") -def _prod(a, *args, **kwargs): - arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where") - all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs) - axis = all_kwargs.get("axis", None) - where = all_kwargs.get("where", None) - - registry = a.units._REGISTRY - - if axis is not None and where is not None: - _, where_ = np.broadcast_arrays(a._magnitude, where) - exponents = np.unique(np.sum(where_, axis=axis)) - if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents): - units = a.units ** np.max(exponents) +def implement_prod_func(name): + if np is None: + return + + func = getattr(np, name, None) + if func is None: + return + + @implements(name, "function") + def _prod(a, *args, **kwargs): + arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where") + all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs) + axis = all_kwargs.get("axis", None) + where = all_kwargs.get("where", None) + + registry = a.units._REGISTRY + + if axis is not None and where is not None: + _, where_ = np.broadcast_arrays(a._magnitude, where) + exponents = np.unique(np.sum(where_, axis=axis)) + if len(exponents) == 1 or (len(exponents) == 2 and 0 in exponents): + units = a.units ** np.max(exponents) + else: + units = registry.dimensionless + a = a.to(units) + elif axis is not None: + units = a.units ** a.shape[axis] + elif where is not None: + exponent = np.sum(where) + units = a.units ** exponent else: - units = registry.dimensionless - a = a.to(units) - elif axis is not None: - units = a.units ** a.shape[axis] - elif where is not None: - exponent = np.sum(where) - units = a.units ** exponent - else: - units = a.units ** a.size + exponent = ( + np.sum(np.logical_not(np.isnan(a))) if name == "nanprod" else a.size + ) + units = a.units ** exponent + + result = func(a._magnitude, *args, **kwargs) + + return registry.Quantity(result, units) - result = np.prod(a._magnitude, *args, **kwargs) - return registry.Quantity(result, units) +for name in ["prod", "nanprod"]: + implement_prod_func(name) # Implement simple matching-unit or stripped-unit functions based on signature diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index ce337b2..5e9915b 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -329,6 +329,16 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods): np.prod(self.q, axis=axis, where=[True, False]), [3, 1] * self.ureg.m ** 2 ) + @helpers.requires_array_function_protocol() + def test_nanprod_numpy_func(self): + helpers.assert_quantity_equal(np.nanprod(self.q_nan), 6 * self.ureg.m ** 3) + helpers.assert_quantity_equal( + np.nanprod(self.q_nan, axis=0), [3, 2] * self.ureg.m ** 2 + ) + helpers.assert_quantity_equal( + np.nanprod(self.q_nan, axis=1), [2, 3] * self.ureg.m ** 2 + ) + def test_sum(self): assert self.q.sum() == 10 * self.ureg.m helpers.assert_quantity_equal(self.q.sum(0), [4, 6] * self.ureg.m) |