diff options
| author | Ryan May <rmay@ucar.edu> | 2022-09-29 19:58:54 -0600 |
|---|---|---|
| committer | Ryan May <rmay@ucar.edu> | 2023-04-27 18:09:12 -0600 |
| commit | 61571a77e1a765b36ce1a26951975a1332ed3cf6 (patch) | |
| tree | 4c049bcbdb3fba8a38f954d5eff3e3d3de478a5f /pint/facets | |
| parent | 1b54de47fcb3eeaf4c52e5acb519bd212216f413 (diff) | |
| download | pint-61571a77e1a765b36ce1a26951975a1332ed3cf6.tar.gz | |
Properly handle offset units for trapz (Fixes #1593)
Diffstat (limited to 'pint/facets')
| -rw-r--r-- | pint/facets/numpy/numpy_func.py | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index 2a4421c..2a004a8 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -13,7 +13,7 @@ from inspect import signature from itertools import chain from ...compat import is_upcast_type, np, zero_or_nan -from ...errors import DimensionalityError, UnitStrippedWarning +from ...errors import DimensionalityError, OffsetUnitCalculusError, UnitStrippedWarning from ...util import iterable, sized HANDLED_UFUNCS = {} @@ -729,6 +729,36 @@ for name in ["prod", "nanprod"]: implement_prod_func(name) +def _base_unit_if_needed(a): + if a._is_multiplicative: + return a + else: + if a.units._REGISTRY.autoconvert_offset_to_baseunit: + return a.to_base_units() + else: + raise OffsetUnitCalculusError(a.units) + + +@implements("trapz", "function") +def _trapz(a, x=None, dx=1.0, **kwargs): + a = _base_unit_if_needed(a) + units = a.units + if x is not None: + if hasattr(x, "units"): + x = _base_unit_if_needed(x) + units *= x.units + x = x._magnitude + ret = np.trapz(a._magnitude, x, **kwargs) + else: + if hasattr(dx, "units"): + dx = _base_unit_if_needed(dx) + units *= dx.units + dx = dx._magnitude + ret = np.trapz(a._magnitude, dx=dx, **kwargs) + + return a.units._REGISTRY.Quantity(ret, units) + + # Implement simple matching-unit or stripped-unit functions based on signature @@ -920,7 +950,7 @@ for func_str in [ # Handle functions with output unit defined by operation for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]: implement_func("function", func_str, input_units=None, output_unit="sum") -for func_str in ["cross", "trapz", "dot"]: +for func_str in ["cross", "dot"]: implement_func("function", func_str, input_units=None, output_unit="mul") for func_str in ["diff", "ediff1d"]: implement_func("function", func_str, input_units=None, output_unit="delta") |
