summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorHernan Grecco <hernan.grecco@gmail.com>2023-04-29 10:18:34 -0300
committerGitHub <noreply@github.com>2023-04-29 10:18:34 -0300
commit87fb9d7ce3511bc4e150cc6a3b6f4a480998fe77 (patch)
tree1c15d3840ad870ba6e5f46a273a3d29db11f2665 /pint
parent1d65a3f34e5ec7dd2d62426cfb9e73c116bfe04e (diff)
parent819f92008ad3549a4a189513af86f1ad8a51c4d9 (diff)
downloadpint-87fb9d7ce3511bc4e150cc6a3b6f4a480998fe77.tar.gz
Merge pull request #1594 from dopplershift/fix-trapz-temp
Properly handle offset units for trapz
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/numpy/numpy_func.py59
-rw-r--r--pint/testsuite/test_numpy_func.py64
2 files changed, 120 insertions, 3 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 2a4421c..f25f4a4 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -13,7 +13,7 @@ from inspect import signature
from itertools import chain
from ...compat import is_upcast_type, np, zero_or_nan
-from ...errors import DimensionalityError, UnitStrippedWarning
+from ...errors import DimensionalityError, OffsetUnitCalculusError, UnitStrippedWarning
from ...util import iterable, sized
HANDLED_UFUNCS = {}
@@ -729,6 +729,61 @@ for name in ["prod", "nanprod"]:
implement_prod_func(name)
+# Handle mutliplicative functions separately to deal with non-multiplicative units
+def _base_unit_if_needed(a):
+ if a._is_multiplicative:
+ return a
+ else:
+ if a.units._REGISTRY.autoconvert_offset_to_baseunit:
+ return a.to_base_units()
+ else:
+ raise OffsetUnitCalculusError(a.units)
+
+
+@implements("trapz", "function")
+def _trapz(a, x=None, dx=1.0, **kwargs):
+ a = _base_unit_if_needed(a)
+ units = a.units
+ if x is not None:
+ if hasattr(x, "units"):
+ x = _base_unit_if_needed(x)
+ units *= x.units
+ x = x._magnitude
+ ret = np.trapz(a._magnitude, x, **kwargs)
+ else:
+ if hasattr(dx, "units"):
+ dx = _base_unit_if_needed(dx)
+ units *= dx.units
+ dx = dx._magnitude
+ ret = np.trapz(a._magnitude, dx=dx, **kwargs)
+
+ return a.units._REGISTRY.Quantity(ret, units)
+
+
+def implement_mul_func(func):
+ # If NumPy is not available, do not attempt implement that which does not exist
+ if np is None:
+ return
+
+ func = getattr(np, func_str)
+
+ @implements(func_str, "function")
+ def implementation(a, b, **kwargs):
+ a = _base_unit_if_needed(a)
+ units = a.units
+ if hasattr(b, "units"):
+ b = _base_unit_if_needed(b)
+ units *= b.units
+ b = b._magnitude
+
+ mag = func(a._magnitude, b, **kwargs)
+ return a.units._REGISTRY.Quantity(mag, units)
+
+
+for func_str in ["cross", "dot"]:
+ implement_mul_func(func_str)
+
+
# Implement simple matching-unit or stripped-unit functions based on signature
@@ -920,8 +975,6 @@ for func_str in [
# Handle functions with output unit defined by operation
for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]:
implement_func("function", func_str, input_units=None, output_unit="sum")
-for func_str in ["cross", "trapz", "dot"]:
- implement_func("function", func_str, input_units=None, output_unit="mul")
for func_str in ["diff", "ediff1d"]:
implement_func("function", func_str, input_units=None, output_unit="delta")
for func_str in ["gradient"]:
diff --git a/pint/testsuite/test_numpy_func.py b/pint/testsuite/test_numpy_func.py
index 49caa32..7a0cdb7 100644
--- a/pint/testsuite/test_numpy_func.py
+++ b/pint/testsuite/test_numpy_func.py
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
from unittest.mock import patch
import pytest
@@ -191,3 +192,66 @@ class TestNumPyFuncUtils(TestNumpyMethods):
numpy_wrap("invalid", np.ones, [], {}, [])
# TODO (#905 follow-up): test that NotImplemented is returned when upcast types
# present
+
+ def test_trapz(self):
+ with ExitStack() as stack:
+ stack.callback(
+ setattr,
+ self.ureg,
+ "autoconvert_offset_to_baseunit",
+ self.ureg.autoconvert_offset_to_baseunit,
+ )
+ self.ureg.autoconvert_offset_to_baseunit = True
+ t = self.Q_(np.array([0.0, 4.0, 8.0]), "degC")
+ z = self.Q_(np.array([0.0, 2.0, 4.0]), "m")
+ helpers.assert_quantity_equal(
+ np.trapz(t, x=z), self.Q_(1108.6, "kelvin meter")
+ )
+
+ def test_trapz_no_autoconvert(self):
+ t = self.Q_(np.array([0.0, 4.0, 8.0]), "degC")
+ z = self.Q_(np.array([0.0, 2.0, 4.0]), "m")
+ with pytest.raises(OffsetUnitCalculusError):
+ np.trapz(t, x=z)
+
+ def test_dot(self):
+ with ExitStack() as stack:
+ stack.callback(
+ setattr,
+ self.ureg,
+ "autoconvert_offset_to_baseunit",
+ self.ureg.autoconvert_offset_to_baseunit,
+ )
+ self.ureg.autoconvert_offset_to_baseunit = True
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ helpers.assert_quantity_almost_equal(
+ np.dot(t, z), self.Q_(1678.9, "kelvin meter")
+ )
+
+ def test_dot_no_autoconvert(self):
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ with pytest.raises(OffsetUnitCalculusError):
+ np.dot(t, z)
+
+ def test_cross(self):
+ with ExitStack() as stack:
+ stack.callback(
+ setattr,
+ self.ureg,
+ "autoconvert_offset_to_baseunit",
+ self.ureg.autoconvert_offset_to_baseunit,
+ )
+ self.ureg.autoconvert_offset_to_baseunit = True
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ helpers.assert_quantity_almost_equal(
+ np.cross(t, z), self.Q_([268.15, -536.3, 268.15], "kelvin meter")
+ )
+
+ def test_cross_no_autoconvert(self):
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ with pytest.raises(OffsetUnitCalculusError):
+ np.cross(t, z)