summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pint/registry_helpers.py15
-rw-r--r--pint/testsuite/test_unit.py4
2 files changed, 13 insertions, 6 deletions
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
index ea196d0..c50d22a 100644
--- a/pint/registry_helpers.py
+++ b/pint/registry_helpers.py
@@ -13,7 +13,7 @@ import functools
from .compat import string_types, zip_longest
from .errors import DimensionalityError
-from .util import to_units_container
+from .util import to_units_container, UnitsContainer
def _replace_units(original_units, values_by_name):
@@ -26,7 +26,7 @@ def _replace_units(original_units, values_by_name):
for arg_name, exponent in original_units.items():
q = q * values_by_name[arg_name] ** exponent
- return to_units_container(q)
+ return getattr(q, "_units", UnitsContainer({}))
def _to_units_container(a):
@@ -95,13 +95,16 @@ def _parse_wrap_args(args):
# first pass: Grab named values
for ndx in defs_args_ndx:
- values_by_name[args_as_uc[ndx][0]] = values[ndx]
- new_values[ndx] = values[ndx]._magnitude
+ value = values[ndx]
+ values_by_name[args_as_uc[ndx][0]] = value
+ new_values[ndx] = getattr(value, "_magnitude", value)
# second pass: calculate derived values based on named values
for ndx in dependent_args_ndx:
- new_values[ndx] = ureg._convert(values[ndx]._magnitude,
- values[ndx]._units,
+ value = values[ndx]
+ assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None
+ new_values[ndx] = ureg._convert(getattr(value, "_magnitude", value),
+ getattr(value, "_units", UnitsContainer({})),
_replace_units(args_as_uc[ndx][0], values_by_name))
# third pass: convert other arguments
diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py
index 883ce81..90d68cd 100644
--- a/pint/testsuite/test_unit.py
+++ b/pint/testsuite/test_unit.py
@@ -364,6 +364,7 @@ class TestRegistry(QuantityTestCase):
g0 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
self.assertEqual(g0(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter'))
+ self.assertEqual(g0(3, 1), 4)
g1 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
self.assertEqual(g1(3. * ureg.meter, 1. * ureg.centimeter), rst.to('centimeter'))
@@ -375,9 +376,12 @@ class TestRegistry(QuantityTestCase):
a = 3. * ureg.meter
b = (2. * ureg.centimeter) ** 2
self.assertEqual(g3(a, b), gfunc2(a, b))
+ self.assertEqual(g3(3, 2), gfunc2(3, 2))
g4 = ureg.wraps('=A**2 * B', ['=A', '=B'])(gfunc3)
self.assertEqual(g4(3. * ureg.meter, 2. * ureg.second), ureg('(3*meter)**2 * 2 *second'))
+ self.assertEqual(g4(3. * ureg.meter, 2.), ureg('(3*meter)**2 * 2'))
+ self.assertEqual(g4(3., 2. * ureg.second), ureg('3**2 * 2 * second'))
def test_check(self):