diff options
| author | Keewis <keewis@posteo.de> | 2020-04-23 16:30:53 +0200 |
|---|---|---|
| committer | Keewis <keewis@posteo.de> | 2020-04-23 16:35:21 +0200 |
| commit | c12aac68f28c1d229b1f90336c17ec2fbf086ebc (patch) | |
| tree | 4ca0a90d25701c7beb34028e58af9ccf2d7dc50b /pint/numpy_func.py | |
| parent | e06902c9f6850717b1c4c5ad0ffbf57785bb4959 (diff) | |
| download | pint-c12aac68f28c1d229b1f90336c17ec2fbf086ebc.tar.gz | |
implement numpy.prod
Diffstat (limited to 'pint/numpy_func.py')
| -rw-r--r-- | pint/numpy_func.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 8501a7e..7188071 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -670,6 +670,28 @@ def _all(a, *args, **kwargs): raise ValueError("Boolean value of Quantity with offset unit is ambiguous.") +@implements("prod", "function") +def _prod(a, *args, **kwargs): + arg_names = ("axis", "dtype", "out", "keepdims", "initial", "where") + all_kwargs = dict(**dict(zip(arg_names, args)), **kwargs) + axis = all_kwargs.get("axis", None) + where = all_kwargs.get("where", None) + + if axis is not None and where is not None: + raise ValueError("passing axis and where is not supported") + + result = np.prod(a._magnitude, *args, **kwargs) + + if axis is not None: + exponent = a.size // result.size + units = a.units ** exponent + elif where is not None: + exponent = np.asarray(where, dtype=np.bool_).sum() + units = a.units ** exponent + + return units._REGISTRY.Quantity(result, units) + + # Implement simple matching-unit or stripped-unit functions based on signature |
