summaryrefslogtreecommitdiff
path: root/pint/numpy_func.py
diff options
context:
space:
mode:
authorKeewis <keewis@posteo.de>2020-04-23 16:30:53 +0200
committerKeewis <keewis@posteo.de>2020-04-23 16:35:21 +0200
commitc12aac68f28c1d229b1f90336c17ec2fbf086ebc (patch)
tree4ca0a90d25701c7beb34028e58af9ccf2d7dc50b /pint/numpy_func.py
parente06902c9f6850717b1c4c5ad0ffbf57785bb4959 (diff)
downloadpint-c12aac68f28c1d229b1f90336c17ec2fbf086ebc.tar.gz
implement numpy.prod
Diffstat (limited to 'pint/numpy_func.py')
-rw-r--r--pint/numpy_func.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index 8501a7e..7188071 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -670,6 +670,28 @@ def _all(a, *args, **kwargs):
raise ValueError("Boolean value of Quantity with offset unit is ambiguous.")
+@implements("prod", "function")
+def _prod(a, *args, **kwargs):
+ arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where")
+ all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs)
+ axis = all_kwargs.get("axis", None)
+ where = all_kwargs.get("where", None)
+
+ if axis is not None and where is not None:
+ raise ValueError("passing axis and where is not supported")
+
+ result = np.prod(a._magnitude, *args, **kwargs)
+
+ if axis is not None:
+ exponent = a.size // result.size
+ units = a.units ** exponent
+ elif where is not None:
+ exponent = np.asarray(where, dtype=np.bool_).sum()
+ units = a.units ** exponent
+
+ return units._REGISTRY.Quantity(result, units)
+
+
# Implement simple matching-unit or stripped-unit functions based on signature