summaryrefslogtreecommitdiff
path: root/pint/facets
diff options
context:
space:
mode:
authorRyan May <rmay@ucar.edu>2022-09-29 19:58:54 -0600
committerRyan May <rmay@ucar.edu>2023-04-27 18:09:12 -0600
commit61571a77e1a765b36ce1a26951975a1332ed3cf6 (patch)
tree4c049bcbdb3fba8a38f954d5eff3e3d3de478a5f /pint/facets
parent1b54de47fcb3eeaf4c52e5acb519bd212216f413 (diff)
downloadpint-61571a77e1a765b36ce1a26951975a1332ed3cf6.tar.gz
Properly handle offset units for trapz (Fixes #1593)
Diffstat (limited to 'pint/facets')
-rw-r--r--pint/facets/numpy/numpy_func.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 2a4421c..2a004a8 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -13,7 +13,7 @@ from inspect import signature
from itertools import chain
from ...compat import is_upcast_type, np, zero_or_nan
-from ...errors import DimensionalityError, UnitStrippedWarning
+from ...errors import DimensionalityError, OffsetUnitCalculusError, UnitStrippedWarning
from ...util import iterable, sized
HANDLED_UFUNCS = {}
@@ -729,6 +729,36 @@ for name in ["prod", "nanprod"]:
implement_prod_func(name)
+def _base_unit_if_needed(a):
+ if a._is_multiplicative:
+ return a
+ else:
+ if a.units._REGISTRY.autoconvert_offset_to_baseunit:
+ return a.to_base_units()
+ else:
+ raise OffsetUnitCalculusError(a.units)
+
+
+@implements("trapz", "function")
+def _trapz(a, x=None, dx=1.0, **kwargs):
+ a = _base_unit_if_needed(a)
+ units = a.units
+ if x is not None:
+ if hasattr(x, "units"):
+ x = _base_unit_if_needed(x)
+ units *= x.units
+ x = x._magnitude
+ ret = np.trapz(a._magnitude, x, **kwargs)
+ else:
+ if hasattr(dx, "units"):
+ dx = _base_unit_if_needed(dx)
+ units *= dx.units
+ dx = dx._magnitude
+ ret = np.trapz(a._magnitude, dx=dx, **kwargs)
+
+ return a.units._REGISTRY.Quantity(ret, units)
+
+
# Implement simple matching-unit or stripped-unit functions based on signature
@@ -920,7 +950,7 @@ for func_str in [
# Handle functions with output unit defined by operation
for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]:
implement_func("function", func_str, input_units=None, output_unit="sum")
-for func_str in ["cross", "trapz", "dot"]:
+for func_str in ["cross", "dot"]:
implement_func("function", func_str, input_units=None, output_unit="mul")
for func_str in ["diff", "ediff1d"]:
implement_func("function", func_str, input_units=None, output_unit="delta")