summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hernan.grecco@gmail.com>2016-02-12 14:07:44 -0300
committerHernan Grecco <hernan.grecco@gmail.com>2016-02-12 15:10:18 -0300
commitf70d749936a8ba6009a8e1f0fee6e4b1ae62ab64 (patch)
tree02e07f2b22364002c22ace140493ea61b405c636
parent457df519c72da34ec45d8c147c56d81f07115a81 (diff)
downloadpint-f70d749936a8ba6009a8e1f0fee6e4b1ae62ab64.tar.gz
Implemented reference in wraps decorator.
We use an API based on strings prefixed with the equal sign. Each parameter can be labeled with a unique name. Parameters can reference other using labels to build up relationships. Close #195
-rw-r--r--docs/wrapping.rst28
-rw-r--r--pint/registry_helpers.py217
-rw-r--r--pint/testsuite/test_unit.py37
-rw-r--r--pint/unit.py98
-rw-r--r--pint/util.py1
5 files changed, 283 insertions, 98 deletions
diff --git a/docs/wrapping.rst b/docs/wrapping.rst
index afbf0d6..bdcb81a 100644
--- a/docs/wrapping.rst
+++ b/docs/wrapping.rst
@@ -1,7 +1,7 @@
.. _wrapping:
-Wrapping functions
-==================
+Wrapping and checking functions
+===============================
In some cases you might want to use pint with a pre-existing web service or library
which is not units aware. Or you might want to write a fast implementation of a
@@ -137,6 +137,30 @@ Or if the function has multiple outputs:
... (ureg.meter, ureg.radians))(pendulum_period_maxspeed)
...
+
+Specifying relations between arguments
+--------------------------------------
+
+In certain cases the actual units but just their relation. This is done using string
+starting with the equal sign `=`:
+
+.. doctest::
+
+ >>> @ureg.wraps('=A**2', ('=A', '=A'))
+ ... def sqsum(x, y):
+ ... return x * x + 2 * x * y + y * y
+
+which can be read as the first argument (`x`) has certain units (we labeled them `A`),
+the second argument (`y`) has the same units as the first (`A` again). The return value
+has the unit of `x` squared (`A**2`)
+
+You can use more than one labels.
+
+ >>> @ureg.wraps('=A**2*B', ('=A', '=A*B', '=B'))
+ ... def some_function(x, y, z):
+
+
+
Ignoring an argument or return value
------------------------------------
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
new file mode 100644
index 0000000..66d57da
--- /dev/null
+++ b/pint/registry_helpers.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+"""
+ pint.registry_helpers
+ ~~~~~~~~~~~~~~~~~~~~~
+
+ Miscellaneous methods of the registry writen as separate functions.
+
+ :copyright: 2013 by Pint Authors, see AUTHORS for more details.
+ :license: BSD, see LICENSE for more details.
+"""
+
+import functools
+
+from .compat import string_types, zip_longest
+from .errors import DimensionalityError
+from .util import to_units_container
+
+
+def _replace_units(original_units, values_by_name):
+ """Convert a unit compatible type to a UnitsContainer.
+
+ :param original_units: a UnitsContainer instance.
+ :param values_by_name: a map between original names and the new values.
+ """
+ q = 1
+ for arg_name, exponent in original_units.items():
+ q = q * values_by_name[arg_name] ** exponent
+
+ return to_units_container(q)
+
+
+def _to_units_container(a):
+ """Convert a unit compatible type to a UnitsContainer,
+ checking if it is string field prefixed with an equal
+ (which is considered a reference)
+
+ Return a tuple with the unit container and a boolean indicating if it was a reference.
+ """
+ if isinstance(a, string_types) and '=' in a:
+ return to_units_container(a.split('=', 1)[1]), True
+ return to_units_container(a), False
+
+
+def _parse_wrap_args(args):
+
+ # Arguments which contain definitions
+ # (i.e. names that appear alone and for the first time)
+ defs_args = set()
+ defs_args_ndx = set()
+
+ # Arguments which depend on others
+ dependent_args_ndx = set()
+
+ # Arguments which have units.
+ unit_args_ndx = set()
+
+ # _to_units_container
+ args_as_uc = [_to_units_container(arg) for arg in args]
+
+ # Check for references in args, remove None values
+ for ndx, (arg, is_ref) in enumerate(args_as_uc):
+ if arg is None:
+ continue
+ elif is_ref:
+ if len(arg) == 1:
+ [(key, value)] = arg.items()
+ if value == 1 and key not in defs_args:
+ # This is the first time that
+ # a variable is used => it is a definition.
+ defs_args.add(key)
+ defs_args_ndx.add(ndx)
+ args_as_uc[ndx] = (key, True)
+ else:
+ # The variable was already found elsewhere,
+ # we consider it a dependent variable.
+ dependent_args_ndx.add(ndx)
+ else:
+ dependent_args_ndx.add(ndx)
+ else:
+ unit_args_ndx.add(ndx)
+
+ # Check that all valid dependent variables
+ for ndx in dependent_args_ndx:
+ arg, is_ref = args_as_uc[ndx]
+ if not isinstance(arg, dict):
+ continue
+ if not set(arg.keys()) <= defs_args:
+ raise ValueError('Found a missing token while wrapping a function: '
+ 'Not all variable referenced in %s are defined using !' % args[ndx])
+
+ def _converter(ureg, values, strict):
+ new_values = list(value for value in values)
+
+ values_by_name = {}
+
+ # first pass: Grab named values
+ for ndx in defs_args_ndx:
+ values_by_name[args_as_uc[ndx][0]] = values[ndx]
+ new_values[ndx] = values[ndx]._magnitude
+
+ # second pass: calculate derived values based on named values
+ for ndx in dependent_args_ndx:
+ new_values[ndx] = ureg._convert(values[ndx]._magnitude,
+ values[ndx]._units,
+ _replace_units(args_as_uc[ndx][0], values_by_name))
+
+ # third pass: convert other arguments
+ for ndx in unit_args_ndx:
+
+ if isinstance(values[ndx], ureg.Quantity):
+ new_values[ndx] = ureg._convert(values[ndx]._magnitude,
+ values[ndx]._units,
+ args_as_uc[ndx][0])
+ else:
+ if strict:
+ raise ValueError('A wrapped function using strict=True requires '
+ 'quantity for all arguments with not None units. '
+ '(error found for {0}, {1})'.format(args_as_uc[ndx][0], new_values[ndx]))
+
+ return new_values, values_by_name
+
+ return _converter
+
+
+def wraps(ureg, ret, args, strict=True):
+ """Wraps a function to become pint-aware.
+
+ Use it when a function requires a numerical value but in some specific
+ units. The wrapper function will take a pint quantity, convert to the units
+ specified in `args` and then call the wrapped function with the resulting
+ magnitude.
+
+ The value returned by the wrapped function will be converted to the units
+ specified in `ret`.
+
+ Use None to skip argument conversion.
+ Set strict to False, to accept also numerical values.
+
+ :param ureg: a UnitRegistry instance.
+ :param ret: output units.
+ :param args: iterable of input units.
+ :param strict: boolean to indicate that only quantities are accepted.
+ :return: the wrapped function.
+ :raises:
+ :class:`ValueError` if strict and one of the arguments is not a Quantity.
+ """
+
+ if not isinstance(args, (list, tuple)):
+ args = (args, )
+
+ converter = _parse_wrap_args(args)
+
+ if isinstance(ret, (list, tuple)):
+ container, ret = True, ret.__class__([_to_units_container(arg) for arg in ret])
+ elif isinstance(ret, string_types):
+ container, ret = False, _to_units_container(ret)
+ else:
+ container = False
+
+ def decorator(func):
+ assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
+ updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
+
+ @functools.wraps(func, assigned=assigned, updated=updated)
+ def wrapper(*values, **kw):
+
+ # In principle, the values are used as is
+ # When then extract the magnitudes when needed.
+ new_values, values_by_name = converter(ureg, values, strict)
+
+ result = func(*new_values, **kw)
+
+ if container:
+ out_units = (_replace_units(r, values_by_name) if is_ref else r
+ for (r, is_ref) in ret)
+ return ret.__class__(res if unit is None else ureg.Quantity(res, unit)
+ for unit, res in zip(out_units, result))
+
+ if ret is None:
+ return result
+
+ return ureg.Quantity(result,
+ _replace_units(ret[0], values_by_name) if ret[1] else ret[0])
+
+ return wrapper
+ return decorator
+
+
+def check(ureg, *args):
+ """Decorator to for quantity type checking for function inputs.
+
+ Use it to ensure that the decorated function input parameters match
+ the expected type of pint quantity.
+
+ Use None to skip argument checking.
+
+ :param ureg: a UnitRegistry instance.
+ :param args: iterable of input units.
+ :return: the wrapped function.
+ :raises:
+ :class:`DimensionalityError` if the parameters don't match dimensions
+ """
+ dimensions = [ureg.get_dimensionality(dim) for dim in args]
+
+ def decorator(func):
+ assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
+ updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
+
+ @functools.wraps(func, assigned=assigned, updated=updated)
+ def wrapper(*values, **kwargs):
+ for dim, value in zip_longest(dimensions, values):
+ if dim and value.dimensionality != dim:
+ raise DimensionalityError(value, 'a quantity of',
+ value.dimensionality, dim)
+ return func(*values, **kwargs)
+ return wrapper
+ return decorator
diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py
index 97f9fae..f966906 100644
--- a/pint/testsuite/test_unit.py
+++ b/pint/testsuite/test_unit.py
@@ -330,12 +330,45 @@ class TestRegistry(QuantityTestCase):
h0 = ureg.wraps(None, [None, None])(hfunc)
self.assertEqual(h0(3, 1), (3, 1))
- h1 = ureg.wraps(['meter', 'cm'], [None, None])(hfunc)
+ h1 = ureg.wraps(['meter', 'centimeter'], [None, None])(hfunc)
self.assertEqual(h1(3, 1), [3 * ureg.meter, 1 * ureg.cm])
- h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc)
+ h2 = ureg.wraps(('meter', 'centimeter'), [None, None])(hfunc)
self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm))
+ def test_wrap_referencing(self):
+
+ ureg = self.ureg
+
+ def gfunc(x, y):
+ return x + y
+
+ def gfunc2(x, y):
+ return x ** 2 + y
+
+ def gfunc3(x, y):
+ return x ** 2 * y
+
+ rst = 3. * ureg.meter + 1. * ureg.centimeter
+
+ g0 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
+ self.assertEqual(g0(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter'))
+
+ g1 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
+ self.assertEqual(g1(3. * ureg.meter, 1. * ureg.centimeter), rst.to('centimeter'))
+
+ g2 = ureg.wraps('=A', ['=A', '=A'])(gfunc)
+ self.assertEqual(g2(3. * ureg.meter, 1. * ureg.centimeter), rst.to('meter'))
+
+ g3 = ureg.wraps('=A**2', ['=A', '=A**2'])(gfunc2)
+ a = 3. * ureg.meter
+ b = (2. * ureg.centimeter) ** 2
+ self.assertEqual(g3(a, b), gfunc2(a, b))
+
+ g4 = ureg.wraps('=A**2 * B', ['=A', '=B'])(gfunc3)
+ self.assertEqual(g4(3. * ureg.meter, 2. * ureg.second), ureg('(3*meter)**2 * 2 *second'))
+
+
def test_check(self):
def func(x):
return x
diff --git a/pint/unit.py b/pint/unit.py
index 335cb54..bd62ca9 100644
--- a/pint/unit.py
+++ b/pint/unit.py
@@ -14,7 +14,6 @@ from __future__ import division, unicode_literals, print_function, absolute_impo
import os
import math
import itertools
-import functools
import operator
import pkg_resources
from decimal import Decimal
@@ -25,6 +24,7 @@ from collections import defaultdict
from tokenize import untokenize, NUMBER, STRING, NAME, OP
from numbers import Number
+from . import registry_helpers
from .context import Context, ContextChain
from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
string_preprocessor, find_connected_nodes,
@@ -32,7 +32,7 @@ from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
SharedRegistryObject, to_units_container,
fix_str_conversions, SourceIterator)
-from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type, zip_longest
+from .compat import tokenizer, string_types, NUMERIC_TYPES, long_type
from .formatting import siunitx_format_unit
from .definitions import (Definition, UnitDefinition, PrefixDefinition,
DimensionDefinition)
@@ -1267,99 +1267,9 @@ class UnitRegistry(object):
__call__ = parse_expression
- def wraps(self, ret, args, strict=True):
- """Wraps a function to become pint-aware.
+ wraps = registry_helpers.wraps
- Use it when a function requires a numerical value but in some specific
- units. The wrapper function will take a pint quantity, convert to the units
- specified in `args` and then call the wrapped function with the resulting
- magnitude.
-
- The value returned by the wrapped function will be converted to the units
- specified in `ret`.
-
- Use None to skip argument conversion.
- Set strict to False, to accept also numerical values.
-
- :param ret: output units.
- :param args: iterable of input units.
- :param strict: boolean to indicate that only quantities are accepted.
- :return: the wrapped function.
- :raises:
- :class:`ValueError` if strict and one of the arguments is not a Quantity.
- """
-
- Q_ = self.Quantity
-
- if not isinstance(args, (list, tuple)):
- args = (args, )
-
- units = [to_units_container(arg, self) for arg in args]
-
- if isinstance(ret, (list, tuple)):
- ret = ret.__class__([to_units_container(arg, self) for arg in ret])
- elif isinstance(ret, string_types):
- ret = self.parse_units(ret)
-
- def decorator(func):
- assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
- updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
- @functools.wraps(func, assigned=assigned, updated=updated)
- def wrapper(*values, **kw):
- new_args = []
- for unit, value in zip(units, values):
- if unit is None:
- new_args.append(value)
- elif isinstance(value, Q_):
- new_args.append(self._convert(value._magnitude,
- value._units, unit))
- elif not strict:
- new_args.append(value)
- else:
- raise ValueError('A wrapped function using strict=True requires '
- 'quantity for all arguments with not None units. '
- '(error found for {0}, {1})'.format(unit, value))
-
- result = func(*new_args, **kw)
-
- if isinstance(ret, (list, tuple)):
- return ret.__class__(res if unit is None else Q_(res, unit)
- for unit, res in zip(ret, result))
- elif ret is not None:
- return Q_(result, ret)
-
- return result
- return wrapper
- return decorator
-
- def check(self, *args):
- """Decorator to for quantity type checking for function inputs.
-
- Use it to ensure that the decorated function input parameters match
- the expected type of pint quantity.
-
- Use None to skip argument checking.
-
- :param args: iterable of input units.
- :return: the wrapped function.
- :raises:
- :class:`DimensionalityError` if the parameters don't match dimensions
- """
- dimensions = [self.get_dimensionality(dim) for dim in args]
-
- def decorator(func):
- assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
- updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))
-
- @functools.wraps(func, assigned=assigned, updated=updated)
- def wrapper(*values, **kwargs):
- for dim, value in zip_longest(dimensions, values):
- if dim and value.dimensionality != dim:
- raise DimensionalityError(value, 'a quantity of',
- value.dimensionality, dim)
- return func(*values, **kwargs)
- return wrapper
- return decorator
+ check = registry_helpers.check
def build_unit_class(registry):
diff --git a/pint/util.py b/pint/util.py
index dd9cd04..b0131a5 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -11,6 +11,7 @@
from __future__ import division, unicode_literals, print_function, absolute_import
+from decimal import Decimal
import locale
import sys
import re