diff options
author | Hernan Grecco <hernan.grecco@gmail.com> | 2023-04-29 10:18:34 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-29 10:18:34 -0300 |
commit | 87fb9d7ce3511bc4e150cc6a3b6f4a480998fe77 (patch) | |
tree | 1c15d3840ad870ba6e5f46a273a3d29db11f2665 /pint | |
parent | 1d65a3f34e5ec7dd2d62426cfb9e73c116bfe04e (diff) | |
parent | 819f92008ad3549a4a189513af86f1ad8a51c4d9 (diff) | |
download | pint-87fb9d7ce3511bc4e150cc6a3b6f4a480998fe77.tar.gz |
Merge pull request #1594 from dopplershift/fix-trapz-temp
Properly handle offset units for trapz
Diffstat (limited to 'pint')
-rw-r--r-- | pint/facets/numpy/numpy_func.py | 59 | ||||
-rw-r--r-- | pint/testsuite/test_numpy_func.py | 64 |
2 files changed, 120 insertions, 3 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index 2a4421c..f25f4a4 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -13,7 +13,7 @@ from inspect import signature from itertools import chain from ...compat import is_upcast_type, np, zero_or_nan -from ...errors import DimensionalityError, UnitStrippedWarning +from ...errors import DimensionalityError, OffsetUnitCalculusError, UnitStrippedWarning from ...util import iterable, sized HANDLED_UFUNCS = {} @@ -729,6 +729,61 @@ for name in ["prod", "nanprod"]: implement_prod_func(name) +# Handle mutliplicative functions separately to deal with non-multiplicative units +def _base_unit_if_needed(a): + if a._is_multiplicative: + return a + else: + if a.units._REGISTRY.autoconvert_offset_to_baseunit: + return a.to_base_units() + else: + raise OffsetUnitCalculusError(a.units) + + +@implements("trapz", "function") +def _trapz(a, x=None, dx=1.0, **kwargs): + a = _base_unit_if_needed(a) + units = a.units + if x is not None: + if hasattr(x, "units"): + x = _base_unit_if_needed(x) + units *= x.units + x = x._magnitude + ret = np.trapz(a._magnitude, x, **kwargs) + else: + if hasattr(dx, "units"): + dx = _base_unit_if_needed(dx) + units *= dx.units + dx = dx._magnitude + ret = np.trapz(a._magnitude, dx=dx, **kwargs) + + return a.units._REGISTRY.Quantity(ret, units) + + +def implement_mul_func(func): + # If NumPy is not available, do not attempt implement that which does not exist + if np is None: + return + + func = getattr(np, func_str) + + @implements(func_str, "function") + def implementation(a, b, **kwargs): + a = _base_unit_if_needed(a) + units = a.units + if hasattr(b, "units"): + b = _base_unit_if_needed(b) + units *= b.units + b = b._magnitude + + mag = func(a._magnitude, b, **kwargs) + return a.units._REGISTRY.Quantity(mag, units) + + +for func_str in ["cross", "dot"]: + implement_mul_func(func_str) + + # Implement simple matching-unit or stripped-unit functions based on signature @@ -920,8 +975,6 @@ for func_str in [ # Handle functions with output unit defined by operation for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]: implement_func("function", func_str, input_units=None, output_unit="sum") -for func_str in ["cross", "trapz", "dot"]: - implement_func("function", func_str, input_units=None, output_unit="mul") for func_str in ["diff", "ediff1d"]: implement_func("function", func_str, input_units=None, output_unit="delta") for func_str in ["gradient"]: diff --git a/pint/testsuite/test_numpy_func.py b/pint/testsuite/test_numpy_func.py index 49caa32..7a0cdb7 100644 --- a/pint/testsuite/test_numpy_func.py +++ b/pint/testsuite/test_numpy_func.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack from unittest.mock import patch import pytest @@ -191,3 +192,66 @@ class TestNumPyFuncUtils(TestNumpyMethods): numpy_wrap("invalid", np.ones, [], {}, []) # TODO (#905 follow-up): test that NotImplemented is returned when upcast types # present + + def test_trapz(self): + with ExitStack() as stack: + stack.callback( + setattr, + self.ureg, + "autoconvert_offset_to_baseunit", + self.ureg.autoconvert_offset_to_baseunit, + ) + self.ureg.autoconvert_offset_to_baseunit = True + t = self.Q_(np.array([0.0, 4.0, 8.0]), "degC") + z = self.Q_(np.array([0.0, 2.0, 4.0]), "m") + helpers.assert_quantity_equal( + np.trapz(t, x=z), self.Q_(1108.6, "kelvin meter") + ) + + def test_trapz_no_autoconvert(self): + t = self.Q_(np.array([0.0, 4.0, 8.0]), "degC") + z = self.Q_(np.array([0.0, 2.0, 4.0]), "m") + with pytest.raises(OffsetUnitCalculusError): + np.trapz(t, x=z) + + def test_dot(self): + with ExitStack() as stack: + stack.callback( + setattr, + self.ureg, + "autoconvert_offset_to_baseunit", + self.ureg.autoconvert_offset_to_baseunit, + ) + self.ureg.autoconvert_offset_to_baseunit = True + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + helpers.assert_quantity_almost_equal( + np.dot(t, z), self.Q_(1678.9, "kelvin meter") + ) + + def test_dot_no_autoconvert(self): + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + with pytest.raises(OffsetUnitCalculusError): + np.dot(t, z) + + def test_cross(self): + with ExitStack() as stack: + stack.callback( + setattr, + self.ureg, + "autoconvert_offset_to_baseunit", + self.ureg.autoconvert_offset_to_baseunit, + ) + self.ureg.autoconvert_offset_to_baseunit = True + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + helpers.assert_quantity_almost_equal( + np.cross(t, z), self.Q_([268.15, -536.3, 268.15], "kelvin meter") + ) + + def test_cross_no_autoconvert(self): + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + with pytest.raises(OffsetUnitCalculusError): + np.cross(t, z) |