diff options
-rw-r--r-- | pint/registry_helpers.py | 15 | ||||
-rw-r--r-- | pint/testsuite/test_unit.py | 4 |
2 files changed, 13 insertions, 6 deletions
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index ea196d0..c50d22a 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -13,7 +13,7 @@ import functools from .compat import string_types, zip_longest from .errors import DimensionalityError -from .util import to_units_container +from .util import to_units_container, UnitsContainer def _replace_units(original_units, values_by_name): @@ -26,7 +26,7 @@ def _replace_units(original_units, values_by_name): for arg_name, exponent in original_units.items(): q = q * values_by_name[arg_name] ** exponent - return to_units_container(q) + return getattr(q, "_units", UnitsContainer({})) def _to_units_container(a): @@ -95,13 +95,16 @@ def _parse_wrap_args(args): # first pass: Grab named values for ndx in defs_args_ndx: - values_by_name[args_as_uc[ndx][0]] = values[ndx] - new_values[ndx] = values[ndx]._magnitude + value = values[ndx] + values_by_name[args_as_uc[ndx][0]] = value + new_values[ndx] = getattr(value, "_magnitude", value) # second pass: calculate derived values based on named values for ndx in dependent_args_ndx: - new_values[ndx] = ureg._convert(values[ndx]._magnitude, - values[ndx]._units, + value = values[ndx] + assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None + new_values[ndx] = ureg._convert(getattr(value, "_magnitude", value), + getattr(value, "_units", UnitsContainer({})), _replace_units(args_as_uc[ndx][0], values_by_name)) # third pass: convert other arguments diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 883ce81..90d68cd 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -364,6 +364,7 @@ class TestRegistry(QuantityTestCase): g0 = ureg.wraps('=A', ['=A', '=A'])(gfunc) self.assertEqual(g0(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter')) + self.assertEqual(g0(3, 1), 4) g1 = ureg.wraps('=A', ['=A', '=A'])(gfunc) self.assertEqual(g1(3. * ureg.meter, 1. * ureg.centimeter), rst.to('centimeter')) @@ -375,9 +376,12 @@ class TestRegistry(QuantityTestCase): a = 3. * ureg.meter b = (2. * ureg.centimeter) ** 2 self.assertEqual(g3(a, b), gfunc2(a, b)) + self.assertEqual(g3(3, 2), gfunc2(3, 2)) g4 = ureg.wraps('=A**2 * B', ['=A', '=B'])(gfunc3) self.assertEqual(g4(3. * ureg.meter, 2. * ureg.second), ureg('(3*meter)**2 * 2 *second')) + self.assertEqual(g4(3. * ureg.meter, 2.), ureg('(3*meter)**2 * 2')) + self.assertEqual(g4(3., 2. * ureg.second), ureg('3**2 * 2 * second')) def test_check(self): |