diff options
-rw-r--r-- | .travis.yml | 21 | ||||
-rw-r--r-- | bench.py | 107 | ||||
-rw-r--r-- | bench_base.yaml | 42 | ||||
-rw-r--r-- | bench_numpy.yaml | 23 | ||||
-rw-r--r-- | docs/getting.rst | 2 | ||||
-rw-r--r-- | pint/__init__.py | 2 | ||||
-rw-r--r-- | pint/compat/__init__.py | 112 | ||||
-rw-r--r-- | pint/compat/chainmap.py (renamed from pint/compat.py) | 35 | ||||
-rw-r--r-- | pint/compat/lrucache.py | 177 | ||||
-rw-r--r-- | pint/compat/nullhandler.py | 32 | ||||
-rw-r--r-- | pint/compat/transformdict.py | 136 | ||||
-rw-r--r-- | pint/context.py | 5 | ||||
-rw-r--r-- | pint/measurement.py | 109 | ||||
-rw-r--r-- | pint/quantity.py | 101 | ||||
-rw-r--r-- | pint/testsuite/__init__.py | 30 | ||||
-rw-r--r-- | pint/testsuite/test_contexts.py | 3 | ||||
-rw-r--r-- | pint/testsuite/test_issues.py | 22 | ||||
-rw-r--r-- | pint/testsuite/test_measurement.py | 53 | ||||
-rw-r--r-- | pint/testsuite/test_numpy.py | 3 | ||||
-rw-r--r-- | pint/testsuite/test_pitheorem.py | 2 | ||||
-rw-r--r-- | pint/testsuite/test_quantity.py | 6 | ||||
-rw-r--r-- | pint/testsuite/test_umath.py | 3 | ||||
-rw-r--r-- | pint/testsuite/test_unit.py | 50 | ||||
-rw-r--r-- | pint/testsuite/test_util.py | 36 | ||||
-rw-r--r-- | pint/unit.py | 107 | ||||
-rw-r--r-- | pint/util.py | 55 |
26 files changed, 983 insertions, 291 deletions
diff --git a/.travis.yml b/.travis.yml index a50582d..9769ad7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,10 +7,14 @@ python: - "3.3" env: - - NUMPY_VERSION=0 - - NUMPY_VERSION=1.6 - - NUMPY_VERSION=1.7 - - NUMPY_VERSION=1.8 + - UNCERTAINTIES="N" NUMPY_VERSION=0 + - UNCERTAINTIES="N" NUMPY_VERSION=1.6 + - UNCERTAINTIES="N" NUMPY_VERSION=1.7 + - UNCERTAINTIES="N" NUMPY_VERSION=1.8 + - UNCERTAINTIES="Y" NUMPY_VERSION=0 + - UNCERTAINTIES="Y" NUMPY_VERSION=1.6 + - UNCERTAINTIES="Y" NUMPY_VERSION=1.7 + - UNCERTAINTIES="Y" NUMPY_VERSION=1.8 branches: only: @@ -19,6 +23,7 @@ branches: install: - if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then pip install --use-mirrors unittest2; fi + - if [[ $UNCERTAINTIES == 'Y' ]]; then pip install --use-mirrors uncertainties; fi - if [ $NUMPY_VERSION = '0' ]; then pip uninstall -y numpy || true; else pip install numpy==$NUMPY_VERSION; fi - pip install . --use-mirrors - pip list --pre @@ -29,6 +34,10 @@ matrix: # Don't run with these version combinations as these NumPy packages don't seem to be installable with Travis exclude: - python: "3.2" - env: NUMPY_VERSION=1.6 + env: UNCERTAINTIES="N" NUMPY_VERSION=1.6 + - python: "3.2" + env: UNCERTAINTIES="Y" NUMPY_VERSION=1.6 + - python: "3.3" + env: UNCERTAINTIES="N" NUMPY_VERSION=1.6 - python: "3.3" - env: NUMPY_VERSION=1.6 + env: UNCERTAINTIES="Y" NUMPY_VERSION=1.6 diff --git a/bench.py b/bench.py new file mode 100644 index 0000000..b0e1efc --- /dev/null +++ b/bench.py @@ -0,0 +1,107 @@ + + +from __future__ import division, unicode_literals, print_function, absolute_import + +import glob +import copy +from timeit import Timer + +import yaml + + +def time_stmt(stmt='pass', setup='pass', number=0, repeat=3): + """Timer function with the same behaviour as running `python -m timeit ` + in the command line. + + :return: elapsed time in seconds or NaN if the command failed. + :rtype: float + """ + + t = Timer(stmt, setup) + + if not number: + # determine number so that 0.2 <= total time < 2.0 + for i in range(1, 10): + number = 10**i + + try: + x = t.timeit(number) + except: + print(t.print_exc()) + return float('NaN') + + if x >= 0.2: + break + + try: + r = t.repeat(repeat, number) + except: + print(t.print_exc()) + return float('NaN') + + best = min(r) + + return best / number + + +def build_task(task, name='', setup='', number=0, repeat=3): + nt = copy.copy(task) + + nt['name'] = (name + ' ' + task.get('name', '')).strip() + nt['setup'] = (setup + '\n' + task.get('setup', '')).strip('\n') + nt['stmt'] = task.get('stmt', '') + nt['number'] = task.get('number', number) + nt['repeat'] = task.get('repeat', repeat) + + return nt + + +def time_task(name, stmt='pass', setup='pass', number=0, repeat=3, stmts='', base=''): + + if base: + nvalue = time_stmt(stmt=base, setup=setup, number=number, repeat=repeat) + yield name + ' (base)', nvalue + suffix = ' (normalized)' + else: + nvalue = 1. + suffix = '' + + if stmt: + value = time_stmt(stmt=stmt, setup=setup, number=number, repeat=repeat) + yield name, value / nvalue + + for task in stmts: + new_task = build_task(task, name, setup, number, repeat) + for task_name, value in time_task(**new_task): + yield task_name + suffix, value / nvalue + + +def time_file(filename, name='', setup='', number=0, repeat=3): + """Open a yaml benchmark file an time each statement, + + yields a tuple with filename, task name, time in seconds. + """ + with open(filename, 'r') as fp: + tasks = yaml.load(fp) + + for task in tasks: + new_task = build_task(task, name, setup, number, repeat) + for task_name, value in time_task(**new_task): + yield task_name, value + + +def main(filenames=None): + if not filenames: + filenames = glob.iglob('bench_*.yaml') + elif isinstance(filenames, basestring): + filenames = [filenames, ] + + for filename in filenames: + print(filename) + print('-' * len(filename)) + print() + for task_name, value in time_file(filename): + print('%.2e %s' % (value, task_name)) + print() + +main() diff --git a/bench_base.yaml b/bench_base.yaml new file mode 100644 index 0000000..f767a04 --- /dev/null +++ b/bench_base.yaml @@ -0,0 +1,42 @@ + + +- name: importing + stmt: import pint + +- name: empty registry + setup: import pint + stmt: ureg = pint.UnitRegistry(None) + +- name: default registry + setup: import pint + stmt: ureg = pint.UnitRegistry() + +- name: finding meter + setup: | + import pint + ureg = pint.UnitRegistry() + stmts: + - name: (attr) + stmt: q = ureg.meter + - name: (item) + stmt: q = ureg['meter'] + +- name: base units + setup: | + import pint + ureg = pint.UnitRegistry() + stmts: + - name: meter + stmt: ureg.get_base_units('meter') + - name: yard + stmt: ureg.get_base_units('yard') + - name: meter / second + stmt: ureg.get_base_units('meter / second') + - name: yard / minute + stmt: ureg.get_base_units('yard / minute') + +- name: build cache + setup: | + import pint + ureg = pint.UnitRegistry() + stmt: ureg._build_cache() diff --git a/bench_numpy.yaml b/bench_numpy.yaml new file mode 100644 index 0000000..8609bc0 --- /dev/null +++ b/bench_numpy.yaml @@ -0,0 +1,23 @@ + + +- name: NumPy + setup: | + import numpy as np + import pint + ureg = pint.UnitRegistry() + stmts: + - name: cosine + setup: | + d = np.arange(0, 90, 10) + r = np.deg2rad(d) + base: np.cos(r) + stmts: + - name: radian + setup: x = r * ureg.radian + stmt: np.cos(x) + - name: dimensionless + setup: x = r * ureg.dimensionless + stmt: np.cos(x) + - name: degree + setup: x = d * ureg.degree + stmt: np.cos(x) diff --git a/docs/getting.rst b/docs/getting.rst index 14f1d7f..a845d6b 100644 --- a/docs/getting.rst +++ b/docs/getting.rst @@ -3,7 +3,7 @@ Installation ============ -Pint has no dependencies except Python_ itself. In runs on Python 2.7 and 3.0+. +Pint has no dependencies except Python_ itself. In runs on Python 2.6 and 3.0+. You can install it using pip_:: diff --git a/pint/__init__.py b/pint/__init__.py index 45de3b0..049a7d6 100644 --- a/pint/__init__.py +++ b/pint/__init__.py @@ -18,7 +18,7 @@ import pkg_resources from .formatter import formatter from .unit import UnitRegistry, DimensionalityError, UndefinedUnitError from .util import pi_theorem, logger -from .measurement import Measurement + from .context import Context _DEFAULT_REGISTRY = UnitRegistry() diff --git a/pint/compat/__init__.py b/pint/compat/__init__.py new file mode 100644 index 0000000..e4221c4 --- /dev/null +++ b/pint/compat/__init__.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" + pint.compat + ~~~~~~~~~~~ + + Compatibility layer. + + :copyright: 2013 by Pint Authors, see AUTHORS for more details. + :license: BSD, see LICENSE for more details. +""" + +from __future__ import division, unicode_literals, print_function, absolute_import + +import sys +import tokenize + +from numbers import Number +from decimal import Decimal + + +PYTHON3 = sys.version >= '3' + +if PYTHON3: + from io import BytesIO + string_types = str + tokenizer = lambda input_string: tokenize.tokenize(BytesIO(input_string.encode('utf-8')).readline) + + def u(x): + return x +else: + from StringIO import StringIO + string_types = basestring + tokenizer = lambda input_string: tokenize.generate_tokens(StringIO(input_string).readline) + + import codecs + string_types = basestring + + def u(x): + return codecs.unicode_escape_decode(x)[0] + + +if sys.version_info < (2, 7): + import unittest2 as unittest +else: + import unittest + + +try: + from collections import Chainmap +except ImportError: + from .chainmap import ChainMap + +try: + from collections import TransformDict +except ImportError: + from .transformdict import TransformDict + +try: + from functools import lru_cache +except ImportError: + from .lrucache import lru_cache + +try: + from logging import NullHandler +except ImportError: + from .nullhandler import NullHandler + +try: + import numpy as np + from numpy import ndarray + + HAS_NUMPY = True + NUMPY_VER = np.__version__ + NUMERIC_TYPES = (Number, Decimal, ndarray, np.number) + + def _to_magnitude(value, force_ndarray=False): + if isinstance(value, (dict, bool)) or value is None: + raise TypeError('Invalid magnitude for Quantity: {0!r}'.format(value)) + elif isinstance(value, string_types) and value == '': + raise ValueError('Quantity magnitude cannot be an empty string.') + elif isinstance(value, (list, tuple)): + return np.asarray(value) + if force_ndarray: + return np.asarray(value) + return value + +except ImportError: + + np = None + + class ndarray(object): + pass + + HAS_NUMPY = False + NUMPY_VER = 0 + NUMERIC_TYPES = (Number, Decimal) + + def _to_magnitude(value, force_ndarray=False): + if isinstance(value, (dict, bool)) or value is None: + raise TypeError('Invalid magnitude for Quantity: {0!r}'.format(value)) + elif isinstance(value, string_types) and value == '': + raise ValueError('Quantity magnitude cannot be an empty string.') + elif isinstance(value, (list, tuple)): + raise TypeError('lists and tuples are valid magnitudes for ' + 'Quantity only when NumPy is present.') + return value + +try: + from uncertainties import ufloat +except ImportError: + ufloat = None + diff --git a/pint/compat.py b/pint/compat/chainmap.py index 65808f1..f4c9a4e 100644 --- a/pint/compat.py +++ b/pint/compat/chainmap.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- """ - pint.compat - ~~~~~~~~~~~ + pint.compat.chainmap + ~~~~~~~~~~~~~~~~~~~~ - Compatibility layer. + Taken from the Python 3.3 source code. - :copyright: 2013 by Pint Authors, see AUTHORS for more details. - :license: BSD, see LICENSE for more details. + :copyright: 2013, PSF + :license: PSF License """ from __future__ import division, unicode_literals, print_function, absolute_import import sys -import logging from collections import MutableMapping if sys.version_info < (3, 0): @@ -152,27 +151,3 @@ class ChainMap(MutableMapping): def clear(self): 'Clear maps[0], leaving maps[1:] intact.' self.maps[0].clear() - - -if hasattr(logging, "NullHandler"): - NullHandler = logging.NullHandler -else: - class NullHandler(logging.Handler): - """ - This handler does nothing. It's intended to be used to avoid the - "No handlers could be found for logger XXX" one-off warning. This is - important for library code, which may contain code to log events. If a user - of the library does not configure logging, the one-off warning might be - produced; to avoid this, the library developer simply needs to instantiate - a NullHandler and add it to the top-level logger of the library module or - package. - """ - def handle(self, record): - pass - - def emit(self, record): - pass - - def createLock(self): - self.lock = None - diff --git a/pint/compat/lrucache.py b/pint/compat/lrucache.py new file mode 100644 index 0000000..868b598 --- /dev/null +++ b/pint/compat/lrucache.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" + pint.compat.lrucache + ~~~~~~~~~~~~~~~~~~~~ + + LRU (least recently used) cache backport. + + From https://code.activestate.com/recipes/578078-py26-and-py30-backport-of-python-33s-lru-cache/ + + :copyright: 2004, Raymond Hettinger, + :license: MIT License +""" + +from collections import namedtuple +from functools import update_wrapper +from threading import RLock + +_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"]) + +class _HashedSeq(list): + __slots__ = 'hashvalue' + + def __init__(self, tup, hash=hash): + self[:] = tup + self.hashvalue = hash(tup) + + def __hash__(self): + return self.hashvalue + +def _make_key(args, kwds, typed, + kwd_mark = (object(),), + fasttypes = set((int, str, frozenset, type(None))), + sorted=sorted, tuple=tuple, type=type, len=len): + 'Make a cache key from optionally typed positional and keyword arguments' + key = args + if kwds: + sorted_items = sorted(kwds.items()) + key += kwd_mark + for item in sorted_items: + key += item + if typed: + key += tuple(type(v) for v in args) + if kwds: + key += tuple(type(v) for k, v in sorted_items) + elif len(key) == 1 and type(key[0]) in fasttypes: + return key[0] + return _HashedSeq(key) + +def lru_cache(maxsize=100, typed=False): + """Least-recently-used cache decorator. + + If *maxsize* is set to None, the LRU features are disabled and the cache + can grow without bound. + + If *typed* is True, arguments of different types will be cached separately. + For example, f(3.0) and f(3) will be treated as distinct calls with + distinct results. + + Arguments to the cached function must be hashable. + + View the cache statistics named tuple (hits, misses, maxsize, currsize) with + f.cache_info(). Clear the cache and statistics with f.cache_clear(). + Access the underlying function with f.__wrapped__. + + See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used + + """ + + # Users should only access the lru_cache through its public API: + # cache_info, cache_clear, and f.__wrapped__ + # The internals of the lru_cache are encapsulated for thread safety and + # to allow the implementation to change (including a possible C version). + + def decorating_function(user_function): + + cache = dict() + stats = [0, 0] # make statistics updateable non-locally + HITS, MISSES = 0, 1 # names for the stats fields + make_key = _make_key + cache_get = cache.get # bound method to lookup key or return None + _len = len # localize the global len() function + lock = RLock() # because linkedlist updates aren't threadsafe + root = [] # root of the circular doubly linked list + root[:] = [root, root, None, None] # initialize by pointing to self + nonlocal_root = [root] # make updateable non-locally + PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields + + if maxsize == 0: + + def wrapper(*args, **kwds): + # no caching, just do a statistics update after a successful call + result = user_function(*args, **kwds) + stats[MISSES] += 1 + return result + + elif maxsize is None: + + def wrapper(*args, **kwds): + # simple caching without ordering or size limit + key = make_key(args, kwds, typed) + result = cache_get(key, root) # root used here as a unique not-found sentinel + if result is not root: + stats[HITS] += 1 + return result + result = user_function(*args, **kwds) + cache[key] = result + stats[MISSES] += 1 + return result + + else: + + def wrapper(*args, **kwds): + # size limited caching that tracks accesses by recency + key = make_key(args, kwds, typed) if kwds or typed else args + with lock: + link = cache_get(key) + if link is not None: + # record recent use of the key by moving it to the front of the list + root, = nonlocal_root + link_prev, link_next, key, result = link + link_prev[NEXT] = link_next + link_next[PREV] = link_prev + last = root[PREV] + last[NEXT] = root[PREV] = link + link[PREV] = last + link[NEXT] = root + stats[HITS] += 1 + return result + result = user_function(*args, **kwds) + with lock: + root, = nonlocal_root + if key in cache: + # getting here means that this same key was added to the + # cache while the lock was released. since the link + # update is already done, we need only return the + # computed result and update the count of misses. + pass + elif _len(cache) >= maxsize: + # use the old root to store the new key and result + oldroot = root + oldroot[KEY] = key + oldroot[RESULT] = result + # empty the oldest link and make it the new root + root = nonlocal_root[0] = oldroot[NEXT] + oldkey = root[KEY] + oldvalue = root[RESULT] + root[KEY] = root[RESULT] = None + # now update the cache dictionary for the new links + del cache[oldkey] + cache[key] = oldroot + else: + # put result in a new link at the front of the list + last = root[PREV] + link = [last, root, key, result] + last[NEXT] = root[PREV] = cache[key] = link + stats[MISSES] += 1 + return result + + def cache_info(): + """Report cache statistics""" + with lock: + return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache)) + + def cache_clear(): + """Clear the cache and cache statistics""" + with lock: + cache.clear() + root = nonlocal_root[0] + root[:] = [root, root, None, None] + stats[:] = [0, 0] + + wrapper.__wrapped__ = user_function + wrapper.cache_info = cache_info + wrapper.cache_clear = cache_clear + return update_wrapper(wrapper, user_function) + + return decorating_function diff --git a/pint/compat/nullhandler.py b/pint/compat/nullhandler.py new file mode 100644 index 0000000..288cbb3 --- /dev/null +++ b/pint/compat/nullhandler.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" + pint.compat.nullhandler + ~~~~~~~~~~~~~~~~~~~~~~~ + + Taken from the Python 2.7 source code. + + :copyright: 2013, PSF + :license: PSF License +""" + + +import logging + +class NullHandler(logging.Handler): + """ + This handler does nothing. It's intended to be used to avoid the + "No handlers could be found for logger XXX" one-off warning. This is + important for library code, which may contain code to log events. If a user + of the library does not configure logging, the one-off warning might be + produced; to avoid this, the library developer simply needs to instantiate + a NullHandler and add it to the top-level logger of the library module or + package. + """ + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None diff --git a/pint/compat/transformdict.py b/pint/compat/transformdict.py new file mode 100644 index 0000000..c01ea30 --- /dev/null +++ b/pint/compat/transformdict.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +""" + pint.compat.transformdict + ~~~~~~~~~~~~~~~~~~~~~~~~~ + + Taken from the Python 3.4 source code. + + :copyright: 2013, PSF + :license: PSF License +""" + +from collections import MutableMapping + +_sentinel = object() + +class TransformDict(MutableMapping): + '''Dictionary that calls a transformation function when looking + up keys, but preserves the original keys. + + >>> d = TransformDict(str.lower) + >>> d['Foo'] = 5 + >>> d['foo'] == d['FOO'] == d['Foo'] == 5 + True + >>> set(d.keys()) + {'Foo'} + ''' + + __slots__ = ('_transform', '_original', '_data') + + def __init__(self, transform, init_dict=None, **kwargs): + '''Create a new TransformDict with the given *transform* function. + *init_dict* and *kwargs* are optional initializers, as in the + dict constructor. + ''' + if not callable(transform): + raise TypeError("expected a callable, got %r" % transform.__class__) + self._transform = transform + # transformed => original + self._original = {} + self._data = {} + if init_dict: + self.update(init_dict) + if kwargs: + self.update(kwargs) + + def getitem(self, key): + 'D.getitem(key) -> (stored key, value)' + transformed = self._transform(key) + original = self._original[transformed] + value = self._data[transformed] + return original, value + + @property + def transform_func(self): + "This TransformDict's transformation function" + return self._transform + + # Minimum set of methods required for MutableMapping + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._original.values()) + + def __getitem__(self, key): + return self._data[self._transform(key)] + + def __setitem__(self, key, value): + transformed = self._transform(key) + self._data[transformed] = value + self._original.setdefault(transformed, key) + + def __delitem__(self, key): + transformed = self._transform(key) + del self._data[transformed] + del self._original[transformed] + + # Methods overriden to mitigate the performance overhead. + + def clear(self): + 'D.clear() -> None. Remove all items from D.' + self._data.clear() + self._original.clear() + + def __contains__(self, key): + return self._transform(key) in self._data + + def get(self, key, default=None): + 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.' + return self._data.get(self._transform(key), default) + + def pop(self, key, default=_sentinel): + '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. + ''' + transformed = self._transform(key) + if default is _sentinel: + del self._original[transformed] + return self._data.pop(transformed) + else: + self._original.pop(transformed, None) + return self._data.pop(transformed, default) + + def popitem(self): + '''D.popitem() -> (k, v), remove and return some (key, value) pair + as a 2-tuple; but raise KeyError if D is empty. + ''' + transformed, value = self._data.popitem() + return self._original.pop(transformed), value + + # Other methods + + def copy(self): + 'D.copy() -> a shallow copy of D' + other = self.__class__(self._transform) + other._original = self._original.copy() + other._data = self._data.copy() + return other + + __copy__ = copy + + def __getstate__(self): + return (self._transform, self._data, self._original) + + def __setstate__(self, state): + self._transform, self._data, self._original = state + + def __repr__(self): + try: + equiv = dict(self) + except TypeError: + # Some keys are unhashable, fall back on .items() + equiv = list(self.items()) + return '%s(%r, %s)' % (self.__class__.__name__, + self._transform, repr(equiv)) diff --git a/pint/context.py b/pint/context.py index a59664e..947d643 100644 --- a/pint/context.py +++ b/pint/context.py @@ -15,8 +15,9 @@ from __future__ import division, unicode_literals, print_function, absolute_impo import re from collections import defaultdict import weakref -from pint.compat import ChainMap -from pint.util import ParserHelper, string_types + +from .compat import ChainMap +from .util import ParserHelper, string_types #: Regex to match the header parts of a context. _header_re = re.compile('@context\s*(?P<defaults>\(.*\))?\s+(?P<name>\w+)\s*(=(?P<aliases>.*))*') diff --git a/pint/measurement.py b/pint/measurement.py index 346ce0c..e3d5f4c 100644 --- a/pint/measurement.py +++ b/pint/measurement.py @@ -9,116 +9,51 @@ from __future__ import division, unicode_literals, print_function, absolute_import +from .compat import ufloat -import operator +MISSING = object() - -class Measurement(object): +class _Measurement(object): """Implements a class to describe a quantity with uncertainty. :param value: The most likely value of the measurement. :type value: Quantity or Number :param error: The error or uncertainty of the measurement. - :type value: Quantity or Number + :type error: Quantity or Number """ - def __init__(self, value, error): - if not (value/error).unitless: - raise ValueError('{0} and {1} have incompatible units'.format(value, error)) + def __new__(cls, value, error, units=MISSING): + if units is MISSING: + try: + value, units = value.magnitude, value.units + except AttributeError: + try: + value, error, units = value.nominal_value, value.std_dev, error + except AttributeError: + units = '' try: - emag = error.magnitude + error = error.to(units).magnitude except AttributeError: - emag = error + pass - if emag < 0: - raise ValueError('The magnitude of the error cannot be negative'.format(value, error)) + inst = super(_Measurement, cls).__new__(cls, ufloat(value, error), units) - self._value = value - self._error = error + if error < 0: + raise ValueError('The magnitude of the error cannot be negative'.format(value, error)) + return inst @property def value(self): - return self._value + return self._REGISTRY.Quantity(self.magnitude.nominal_value, self.units) @property def error(self): - return self._error + return self._REGISTRY.Quantity(self.magnitude.std_dev, self.units) @property def rel(self): - return float(abs(self._error / self._value)) - - def _add_sub(self, other, operator): - result = self.value + other.value - if isinstance(other, self.__class__): - error = (self.error ** 2.0 + other.error ** 2.0) ** (1/2) - else: - error = self.error - return result.plus_minus(error) - - def __add__(self, other): - return self._add_sub(other, operator.add) - - __radd__ = __add__ - - def __sub__(self, other): - return self._add_sub(other, operator.sub) - - __rsub__ = __sub__ - - def _mul_div(self, other, operator): - if isinstance(other, self.__class__): - result = operator(self.value, other.value) - return result.plus_minus((self.rel ** 2.0 + other.rel ** 2.0) ** (1/2), relative=True) - else: - result = operator(self.value, other) - return result.plus_minus(abs(operator(self.error, other))) - - def __mul__(self, other): - return self._mul_div(other, operator.mul) - - __rmul__ = __mul__ - - def __truediv__(self, other): - return self._mul_div(other, operator.truediv) - - def __floordiv__(self, other): - return self._mul_div(other, operator.floordiv) - - __div__ = __floordiv__ - - def __str__(self): - return '{}'.format(self) + return float(abs(self.magnitude.std_dev / self.magnitude.nominal_value)) def __repr__(self): return "<Measurement({0:!r}, {1:!r})>".format(self._value, self._error) - def __format__(self, spec): - if '!' in spec: - fmt, conv = spec.split('!') - conv = '!' + conv - else: - fmt, conv = spec, '' - - left, right = '(', ')' - if '!l' == conv: - pm = r'\pm' - left = r'\left' + left - right = r'\right' + right - elif '!p' == conv: - pm = '±' - else: - pm = '+/-' - - if hasattr(self.value, 'units'): - vmag = format(self.value.magnitude, fmt) - if self.value.units != self.error.units: - emag = self.error.to(self.value.units).magnitude - else: - emag = self.error.magnitude - emag = format(emag, fmt) - units = ' ' + format(self.value.units, conv) - else: - vmag, emag, units = self.value, self.error, '' - - return left + vmag + ' ' + pm + ' ' + emag + right + units if units else '' diff --git a/pint/quantity.py b/pint/quantity.py index 853195b..924f74b 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -16,33 +16,7 @@ from collections import Iterable from .formatter import remove_custom_flags from .unit import DimensionalityError, UnitsContainer, UnitDefinition, UndefinedUnitError -from .measurement import Measurement -from .util import string_types, NUMERIC_TYPES, ndarray - -try: - import numpy as np - - def _to_magnitude(value, force_ndarray=False): - if isinstance(value, (dict, bool)) or value is None: - raise ValueError('Invalid magnitude for Quantity: {!r}'.format(value)) - elif isinstance(value, string_types) and value == '': - raise ValueError('Quantity magnitude cannot be an empty string.') - elif isinstance(value, (list, tuple)): - return np.asarray(value) - if force_ndarray: - return np.asarray(value) - return value - -except ImportError: - def _to_magnitude(value, force_ndarray=False): - if isinstance(value, (dict, bool)) or value is None: - raise ValueError('Invalid magnitude for Quantity: {!r}'.format(value)) - elif isinstance(value, string_types) and value == '': - raise ValueError('Quantity magnitude cannot be an empty string.') - elif isinstance(value, (list, tuple)): - raise ValueError('lists and tuples are valid magnitudes for ' - 'Quantity only when NumPy is present.') - return value +from .compat import string_types, ndarray, np, _to_magnitude def _eq(first, second, check_all): @@ -84,7 +58,7 @@ def _check(q1, other): raise ValueError('Cannot operate between quantities of different registries') -def _has_multiplicative_units(q): +def _only_multiplicative_units(q): """Check if the quantity has non-multiplicative units. """ @@ -118,7 +92,7 @@ class _Quantity(object): if units is None: if isinstance(value, string_types): if value == '': - raise ValueError('Quantity magnitude cannot be an empty string.') + raise ValueError('Expression to parse as Quantity cannot be an empty string.') inst = cls._REGISTRY.parse_expression(value) return cls.__new__(cls, inst) elif isinstance(value, cls): @@ -142,11 +116,18 @@ class _Quantity(object): raise TypeError('units must be of type str, Quantity or ' 'UnitsContainer; not {0}.'.format(type(units))) + inst.__used = False inst.__handling = None return inst + @property + def debug_used(self): + return self.__used + def __copy__(self): - return self.__class__(copy.copy(self._magnitude), copy.copy(self._units)) + ret = self.__class__(copy.copy(self._magnitude), copy.copy(self._units)) + ret.__used = self.__used + return ret def __str__(self): return '{0} {1}'.format(self._magnitude, self._units) @@ -213,6 +194,13 @@ class _Quantity(object): return self._dimensionality + def compatible_units(self, *contexts): + if contexts: + with self._REGISTRY.context(*contexts): + return self._REGISTRY.get_compatible_units(self._units) + + return self._REGISTRY.get_compatible_units(self._units) + def _convert_magnitude(self, other, *contexts, **ctx_kwargs): if contexts: with self._REGISTRY.context(*contexts, **ctx_kwargs): @@ -279,6 +267,13 @@ class _Quantity(object): raise DimensionalityError(self.units, 'dimensionless') def _iadd_sub(self, other, op): + """Perform addition or subtraction operation in-place and return the result. + + :param other: object to be added to / subtracted from self + :type other: Quantity or any type accepted by :func:`_to_magnitude` + :param op: operator function (e.g. operator.add, operator.isub) + :type op: function + """ if _check(self, other): if not self.dimensionality == other.dimensionality: raise DimensionalityError(self.units, other.units, @@ -288,15 +283,26 @@ class _Quantity(object): else: self._magnitude = op(self._magnitude, other.to(self)._magnitude) else: + try: + other_magnitude = _to_magnitude(other, self.force_ndarray) + except TypeError: + return NotImplemented if self.dimensionless: self.ito(UnitsContainer()) - self._magnitude = op(self._magnitude, _to_magnitude(other, self.force_ndarray)) + self._magnitude = op(self._magnitude, other_magnitude) else: raise DimensionalityError(self.units, 'dimensionless') return self def _add_sub(self, other, op): + """Perform addition or subtraction operation and return the result. + + :param other: object to be added to / subtracted from self + :type other: Quantity or any type accepted by :func:`_to_magnitude` + :param op: operator function (e.g. operator.add, operator.isub) + :type op: function + """ if _check(self, other): if not self.dimensionality == other.dimensionality: raise DimensionalityError(self.units, other.units, @@ -337,18 +343,24 @@ class _Quantity(object): def _imul_div(self, other, magnitude_op, units_op=None): """Perform multiplication or division operation in-place and return the result. - Arguments: - other -- object to be multiplied/divided with self - magnitude_op -- operator function to perform on the magnitudes (e.g. operator.mul) - units_op -- operator function to perform on the units; if None, magnitude_op is used - + :param other: object to be multiplied/divided with self + :type other: Quantity or any type accepted by :func:`_to_magnitude` + :param magnitude_op: operator function to perform on the magnitudes (e.g. operator.mul) + :type magnitude_op: function + :param units_op: operator function to perform on the units; if None, *magnitude_op* is used + :type units_op: function or None """ if units_op is None: units_op = magnitude_op - if _check(self, other): - if not _has_multiplicative_units(self): + + if self.__used: + if not _only_multiplicative_units(self): self.ito_base_units() - if not _has_multiplicative_units(other): + else: + self.__used = True + + if _check(self, other): + if not _only_multiplicative_units(other): other = other.to_base_units() self._magnitude = magnitude_op(self._magnitude, other._magnitude) self._units = units_op(self._units, other._units) @@ -410,7 +422,7 @@ class _Quantity(object): except TypeError: return NotImplemented else: - if not _has_multiplicative_units(self): + if not _only_multiplicative_units(self): self.ito_base_units() self._magnitude **= _to_magnitude(other, self.force_ndarray) self._units **= other @@ -785,11 +797,10 @@ class _Quantity(object): def plus_minus(self, error, relative=False): if isinstance(error, self.__class__): if relative: - raise ValueError('{0} is not a valid relative error.'.format(error)) + raise ValueError('{} is not a valid relative error.'.format(error)) + error = error.to(self.units).magnitude else: if relative: - error = error * abs(self) - else: - error = self.__class__(error, self.units) + error = error * abs(self.magnitude) - return Measurement(copy.copy(self), error) + return self._REGISTRY.Measurement(copy.copy(self.magnitude), error, self.units) diff --git a/pint/testsuite/__init__.py b/pint/testsuite/__init__.py index f8e94aa..528ead7 100644 --- a/pint/testsuite/__init__.py +++ b/pint/testsuite/__init__.py @@ -3,37 +3,9 @@ from __future__ import division, unicode_literals, print_function, absolute_import import os -import sys import logging -if sys.version_info < (2, 7): - import unittest2 as unittest -else: - import unittest - -try: - import numpy as np - HAS_NUMPY = True - ndarray = np.ndarray - NUMPY_VER = np.__version__ -except ImportError: - np = None - HAS_NUMPY = False - NUMPY_VER = 0 - class ndarray(object): - pass - -PYTHON3 = sys.version >= '3' - -if PYTHON3: - string_types = str - def u(x): - return x -else: - import codecs - string_types = basestring - def u(x): - return codecs.unicode_escape_decode(x)[0] +from pint.compat import ndarray, unittest from pint import logger, UnitRegistry diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py index 4f93a9c..598f138 100644 --- a/pint/testsuite/test_contexts.py +++ b/pint/testsuite/test_contexts.py @@ -8,7 +8,8 @@ from collections import defaultdict from pint import UnitRegistry from pint.context import Context, _freeze from pint.unit import UnitsContainer -from pint.testsuite import TestCase, unittest +from pint.testsuite import TestCase +from pint.compat import unittest from pint import logger diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py index 0eaecaa..57d675b 100644 --- a/pint/testsuite/test_issues.py +++ b/pint/testsuite/test_issues.py @@ -6,7 +6,9 @@ from pint import UnitRegistry from pint.unit import UnitsContainer from pint.util import ParserHelper -from pint.testsuite import HAS_NUMPY, np, TestCase, ndarray, NUMPY_VER, unittest +from pint.compat import HAS_NUMPY, np, unittest +from pint.testsuite import TestCase + class TestIssues(TestCase): @@ -62,17 +64,19 @@ class TestIssues(TestCase): def test_issue61(self): ureg = UnitRegistry() Q_ = ureg.Quantity - for value in ({}, {'a': 3}, '', None, True, False): - self.assertRaises(ValueError, Q_, value) - self.assertRaises(ValueError, Q_, value, 'meter') + for value in ({}, {'a': 3}, None): + self.assertRaises(TypeError, Q_, value) + self.assertRaises(TypeError, Q_, value, 'meter') + self.assertRaises(ValueError, Q_, '', 'meter') + self.assertRaises(ValueError, Q_, '') @unittest.skipIf(HAS_NUMPY, 'Numpy present') def test_issue61_notNP(self): ureg = UnitRegistry() Q_ = ureg.Quantity for value in ([1, 2, 3], (1, 2, 3)): - self.assertRaises(ValueError, Q_, value) - self.assertRaises(ValueError, Q_, value, 'meter') + self.assertRaises(TypeError, Q_, value) + self.assertRaises(TypeError, Q_, value, 'meter') def test_issue66(self): ureg = UnitRegistry() @@ -163,6 +167,12 @@ class TestIssues(TestCase): self.assertAlmostEqual(v1.to_base_units(), v2) self.assertAlmostEqual(v1.to_base_units(), v2.to_base_units()) + def test_issue86c(self): + ureg = self.ureg + T = ureg.degC + T = 100. * T + self.assertAlmostEqual(ureg.k*2*T, ureg.k*(2*T)) + def test_issue93(self): ureg = UnitRegistry() self.assertIsInstance(ureg.meter.magnitude, int) diff --git a/pint/testsuite/test_measurement.py b/pint/testsuite/test_measurement.py index c4cc14b..731ded8 100644 --- a/pint/testsuite/test_measurement.py +++ b/pint/testsuite/test_measurement.py @@ -2,21 +2,28 @@ from __future__ import division, unicode_literals, print_function, absolute_import -from pint import Measurement -from pint.testsuite import TestCase - +from pint.compat import ufloat +from pint.testsuite import TestCase, unittest +@unittest.skipIf(ufloat is None, 'Uncertainties not installed.') class TestMeasurement(TestCase): FORCE_NDARRAY = False + def test_simple(self): + M_ = self.ureg.Measurement + M_(4.0, 0.1, 's') + def test_build(self): + M_ = self.ureg.Measurement v, u = self.Q_(4.0, 's'), self.Q_(.1, 's') - - ms = (Measurement(v, u), + M_(v.magnitude, u.magnitude, 's') + ms = (M_(v.magnitude, u.magnitude, 's'), + M_(v, u.magnitude), + M_(v, u), v.plus_minus(.1), v.plus_minus(0.025, True), - v.plus_minus(u)) + v.plus_minus(u),) for m in ms: self.assertEqual(m.value, v) @@ -25,7 +32,7 @@ class TestMeasurement(TestCase): def _test_format(self): v, u = self.Q_(4.0, 's'), self.Q_(.1, 's') - m = Measurement(v, u) + m = self.ureg.Measurement(v, u) print(str(m)) print(repr(m)) print('{:!s}'.format(m)) @@ -39,8 +46,8 @@ class TestMeasurement(TestCase): v, u = self.Q_(1.0, 's'), self.Q_(.1, 's') o = self.Q_(.1, 'm') - self.assertRaises(ValueError, Measurement, u, 1) - self.assertRaises(ValueError, Measurement, u, o) + M_ = self.ureg.Measurement + self.assertRaises(ValueError, M_, v, o) self.assertRaises(ValueError, v.plus_minus, o) self.assertRaises(ValueError, v.plus_minus, u, True) @@ -56,18 +63,26 @@ class TestMeasurement(TestCase): for factor, m in zip((3, -3, 3, -3), (m1, m3, m1, m3)): r = factor * m - self.assertAlmostEqual(r.value, factor * m.value) - self.assertAlmostEqual(r.error ** 2.0, (factor * m.error) **2.0) + self.assertAlmostEqual(r.value.magnitude, factor * m.value.magnitude) + self.assertAlmostEqual(r.error.magnitude, abs(factor * m.error.magnitude)) + self.assertEqual(r.value.units, m.value.units) for ml, mr in zip((m1, m1, m1, m3), (m1, m2, m3, m3)): r = ml + mr - self.assertAlmostEqual(r.value, ml.value + mr.value) - self.assertAlmostEqual(r.error ** 2.0, ml.error **2.0 + mr.error ** 2.0) + self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude + mr.value.magnitude) + self.assertAlmostEqual(r.error.magnitude, + ml.error.magnitude + mr.error.magnitude if ml is mr else + (ml.error.magnitude ** 2 + mr.error.magnitude ** 2) ** .5) + self.assertEqual(r.value.units, ml.value.units) for ml, mr in zip((m1, m1, m1, m3), (m1, m2, m3, m3)): r = ml - mr - self.assertAlmostEqual(r.value, ml.value + mr.value) - self.assertAlmostEqual(r.error ** 2.0, ml.error **2.0 + mr.error ** 2.0) + print(ml, mr, ml is mr, r) + self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude - mr.value.magnitude) + self.assertAlmostEqual(r.error.magnitude, + 0 if ml is mr else + (ml.error.magnitude ** 2 + mr.error.magnitude ** 2) ** .5) + self.assertEqual(r.value.units, ml.value.units) def test_propagate_product(self): @@ -84,10 +99,10 @@ class TestMeasurement(TestCase): for ml, mr in zip((m1, m1, m1, m3, m4), (m1, m2, m3, m3, m5)): r = ml * mr - self.assertAlmostEqual(r.value, ml.value * mr.value) - self.assertAlmostEqual(r.rel ** 2.0, ml.rel ** 2.0 + mr.rel ** 2.0) + self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude * mr.value.magnitude) + self.assertEqual(r.value.units, ml.value.units * mr.value.units) for ml, mr in zip((m1, m1, m1, m3, m4), (m1, m2, m3, m3, m5)): r = ml / mr - self.assertAlmostEqual(r.value, ml.value / mr.value) - self.assertAlmostEqual(r.rel ** 2.0, ml.rel ** 2.0 + mr.rel ** 2.0) + self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude / mr.value.magnitude) + self.assertEqual(r.value.units, ml.value.units / mr.value.units) diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 1bf45ba..5063e1c 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -2,7 +2,8 @@ from __future__ import division, unicode_literals, print_function, absolute_import -from pint.testsuite import TestCase, HAS_NUMPY, np, unittest +from pint.compat import HAS_NUMPY, np, unittest +from pint.testsuite import TestCase @unittest.skipUnless(HAS_NUMPY, 'Numpy not present') class TestNumpyMethods(TestCase): diff --git a/pint/testsuite/test_pitheorem.py b/pint/testsuite/test_pitheorem.py index 88f7fbb..3f5a2a1 100644 --- a/pint/testsuite/test_pitheorem.py +++ b/pint/testsuite/test_pitheorem.py @@ -6,7 +6,7 @@ import itertools from pint import pi_theorem -from pint.testsuite import TestCase, unittest +from pint.testsuite import TestCase class TestPiTheorem(TestCase): diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index 3824102..088ea4f 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -6,10 +6,10 @@ import copy import math import operator as op +from pint import DimensionalityError, UnitRegistry from pint.unit import UnitsContainer -from pint import DimensionalityError, UndefinedUnitError, UnitRegistry - -from pint.testsuite import TestCase, string_types, PYTHON3 +from pint.compat import string_types, PYTHON3 +from pint.testsuite import TestCase class TestQuantity(TestCase): diff --git a/pint/testsuite/test_umath.py b/pint/testsuite/test_umath.py index 5d1f0e8..a44993a 100644 --- a/pint/testsuite/test_umath.py +++ b/pint/testsuite/test_umath.py @@ -2,7 +2,8 @@ from __future__ import division, unicode_literals, print_function, absolute_import -from pint.testsuite import TestCase, HAS_NUMPY, np, unittest +from pint.compat import HAS_NUMPY, np, unittest +from pint.testsuite import TestCase # Following http://docs.scipy.org/doc/numpy/reference/ufuncs.html diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 96b7963..f101a59 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -8,9 +8,11 @@ import operator as op from pint.unit import (ScaleConverter, OffsetConverter, UnitsContainer, Definition, PrefixDefinition, UnitDefinition, - DimensionDefinition) + DimensionDefinition, _freeze) from pint import DimensionalityError, UndefinedUnitError -from pint.testsuite import TestCase, u, unittest +from pint.compat import u, unittest +from pint.testsuite import TestCase + class TestConverter(unittest.TestCase): @@ -24,6 +26,7 @@ class TestConverter(unittest.TestCase): self.assertEqual(c.from_reference(c.to_reference(100)), 100) self.assertEqual(c.to_reference(c.from_reference(100)), 100) + class TestDefinition(unittest.TestCase): def test_prefix_definition(self): @@ -336,3 +339,46 @@ class TestRegistry(TestCase): h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc) self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm)) + + +class TestEquivalents(TestCase): + + FORCE_NDARRAY= False + + def _test(self, input_units): + gd = self.ureg.get_dimensionality + dim = gd(input_units) + equiv = self.ureg.get_compatible_units(input_units) + for eq in equiv: + self.assertEqual(gd(eq), dim) + self.assertEqual(equiv, self.ureg.get_compatible_units(dim)) + + def _test2(self, units1, units2): + equiv1 = self.ureg.get_compatible_units(units1) + equiv2 = self.ureg.get_compatible_units(units2) + self.assertEqual(equiv1, equiv2) + + def test_many(self): + self._test(self.ureg.meter.units) + self._test(self.ureg.seconds.units) + self._test(self.ureg.newton.units) + self._test(self.ureg.kelvin.units) + + def test_context_sp(self): + + + gd = self.ureg.get_dimensionality + + # length, frequency, energy + valid = [gd(self.ureg.meter.units), gd(self.ureg.hertz.units), gd(self.ureg.joule.units)] + + with self.ureg.context('sp'): + equiv = self.ureg.get_compatible_units(self.ureg.meter.units) + result = set() + for eq in equiv: + dim = gd(eq) + result.add(_freeze(dim)) + self.assertIn(dim, valid) + + self.assertEqual(len(result), len(valid)) + diff --git a/pint/testsuite/test_util.py b/pint/testsuite/test_util.py index 9264eb6..5532cd0 100644 --- a/pint/testsuite/test_util.py +++ b/pint/testsuite/test_util.py @@ -5,7 +5,7 @@ from __future__ import division, unicode_literals, print_function, absolute_impo import collections from pint.util import string_preprocessor, find_shortest_path -from pint.testsuite import unittest +from pint.compat import unittest class TestStringProcessor(unittest.TestCase): @@ -15,31 +15,51 @@ class TestStringProcessor(unittest.TestCase): a = pattern.format(aft) self.assertEqual(string_preprocessor(b), a) - def test_rules(self): + def test_square_cube(self): self._test('bcd^3', 'bcd**3') + self._test('bcd^ 3', 'bcd** 3') + self._test('bcd ^3', 'bcd **3') + self._test('bcd squared', 'bcd**2') self._test('bcd squared', 'bcd**2') self._test('bcd cubed', 'bcd**3') self._test('sq bcd', 'bcd**2') self._test('square bcd', 'bcd**2') self._test('cubic bcd', 'bcd**3') self._test('bcd efg', 'bcd*efg') - self._test('bcd efg', 'bcd*efg') + + def test_per(self): self._test('miles per hour', 'miles/hour') + + def test_numbers(self): self._test('1,234,567', '1234567') - self._test('1hour', '1*hour') - self._test('1.1hour', '1.1*hour') self._test('1e-24', '1e-24') self._test('1e+24', '1e+24') self._test('1e24', '1e24') self._test('1E-24', '1E-24') self._test('1E+24', '1E+24') self._test('1E24', '1E24') + + def test_space_multiplication(self): + self._test('bcd efg', 'bcd*efg') + self._test('bcd efg', 'bcd*efg') + self._test('1 hour', '1*hour') + self._test('1. hour', '1.*hour') + self._test('1.1 hour', '1.1*hour') + self._test('1E24 hour', '1E24*hour') + self._test('1E-24 hour', '1E-24*hour') + self._test('1E+24 hour', '1E+24*hour') + self._test('1.2E24 hour', '1.2E24*hour') + self._test('1.2E-24 hour', '1.2E-24*hour') + self._test('1.2E+24 hour', '1.2E+24*hour') + + def names(self): self._test('g_0', 'g_0') - self._test('1g_0', '1*g_0') self._test('g0', 'g0') - self._test('1g0', '1*g0') self._test('g', 'g') - self._test('1g', '1*g') + self._test('water_60F', 'water_60F') + + +class TestGraph(unittest.TestCase): def test_shortest_path(self): g = collections.defaultdict(list) diff --git a/pint/unit.py b/pint/unit.py index cbbaeea..c4fa928 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -23,11 +23,11 @@ from io import open from numbers import Number from tokenize import untokenize, NUMBER, STRING, NAME, OP -from .context import Context, ContextChain +from .context import Context, ContextChain, _freeze +from .util import (logger, pi_theorem, solve_dependencies, ParserHelper, + string_preprocessor, find_connected_nodes, find_shortest_path) +from .compat import tokenizer, string_types, NUMERIC_TYPES, TransformDict from .formatter import format_unit -from .util import (logger, NUMERIC_TYPES, pi_theorem, solve_dependencies, - ParserHelper, string_types, ptok, string_preprocessor) -from .util import find_shortest_path class UndefinedUnitError(ValueError): @@ -391,6 +391,7 @@ class UnitRegistry(object): def __init__(self, filename='', force_ndarray=False, default_to_delta=True): self.Quantity = build_quantity_class(self, force_ndarray) + self.Measurement = build_measurement_class(self, force_ndarray) #: Map dimension name (string) to its definition (DimensionDefinition). self._dimensions = {} @@ -410,6 +411,14 @@ class UnitRegistry(object): #: Stores active contexts. self._active_ctx = ContextChain() + #: Maps dimensionality (_freeze(UnitsContainer)) to Units (str) + self._dimensional_equivalents = TransformDict(_freeze) + + #: Maps dimensionality (_freeze(UnitsContainer)) to Dimensionality (_freeze(UnitsContainer)) + self._base_units_cache = TransformDict(_freeze) + #: Maps dimensionality (_freeze(UnitsContainer)) to Units (_freeze(UnitsContainer)) + self._dimensionality_cache = TransformDict(_freeze) + #: When performing a multiplication of units, interpret #: non-multiplicative units as their *delta* counterparts. self.default_to_delta = default_to_delta @@ -422,6 +431,8 @@ class UnitRegistry(object): self.define(UnitDefinition('pi', 'π', (), ScaleConverter(math.pi))) + self._build_cache() + def __getattr__(self, item): return self.Quantity(1, item) @@ -666,18 +677,39 @@ class UnitRegistry(object): except Exception as ex: logger.error("In line {0}, cannot add '{1}' {2}".format(no, line, ex)) - def validate(self): - """Walk the registry and calculate for each unit definition - the corresponding base units and dimensionality. + def _build_cache(self): + """Build a cache of dimensionality and base units. """ - deps = dict((name, set(definition.reference.keys())) + deps = dict((name, set(definition.reference.keys() if definition.reference else {})) for name, definition in self._units.items()) for unit_names in solve_dependencies(deps): for unit_name in unit_names: - bu = self.get_base_units(unit_name) - di = self.get_dimensionality(bu) + prefixed = False + for p in self._prefixes.keys(): + if p and unit_name.startswith(p): + prefixed = True + break + if '[' in unit_name: + continue + try: + uc = ParserHelper.from_word(unit_name) + + bu = self.get_base_units(uc) + di = self.get_dimensionality(uc) + + self._base_units_cache[uc] = bu + self._dimensionality_cache[uc] = di + + if not prefixed: + if di not in self._dimensional_equivalents: + self._dimensional_equivalents[di] = set() + + self._dimensional_equivalents[di].add(self._units[unit_name].name) + + except Exception as e: + logger.warning('Could not resolve {}: {!r}'.format(unit_name, e)) def get_name(self, name_or_alias): """Return the canonical name of a unit. @@ -739,6 +771,9 @@ class UnitRegistry(object): if isinstance(input_units, string_types): input_units = ParserHelper.from_string(input_units) + if input_units in self._dimensionality_cache: + return copy.copy(self._dimensionality_cache[input_units]) + for key, value in input_units.items(): if _is_dim(key): reg = self._dimensions[key] @@ -766,9 +801,8 @@ class UnitRegistry(object): :param input_units: units :type input_units: UnitsContainer or str - :param check_nonmult: if True None will be returned as the multiplicative factor - is a non-multiplicative units is found in the final - Units. + :param check_nonmult: if True, None will be returned as the multiplicative factor + is a non-multiplicative units is found in the final Units. :return: multiplicative factor, base units """ if not input_units: @@ -777,6 +811,10 @@ class UnitRegistry(object): if isinstance(input_units, string_types): input_units = ParserHelper.from_string(input_units) + # The cache is only done for check_nonmult=True + if check_nonmult and input_units in self._base_units_cache: + return copy.deepcopy(self._base_units_cache[input_units]) + factor = 1. units = UnitsContainer() for key, value in input_units.items(): @@ -790,7 +828,7 @@ class UnitRegistry(object): factor *= (reg.converter.scale * fac) ** value units *= uni ** value - # Check if any of the final units is non multiplicative and return non instead. + # Check if any of the final units is non multiplicative and return None instead. if check_nonmult: for unit in units.keys(): if not isinstance(self._units[unit].converter, ScaleConverter): @@ -798,6 +836,26 @@ class UnitRegistry(object): return factor, units + def get_compatible_units(self, input_units): + if not input_units: + return 1., UnitsContainer() + + if isinstance(input_units, string_types): + input_units = ParserHelper.from_string(input_units) + + src_dim = self.get_dimensionality(input_units) + + ret = self._dimensional_equivalents[src_dim] + + if self._active_ctx: + nodes = find_connected_nodes(self._active_ctx.graph, _freeze(src_dim)) + ret = set() + if nodes: + for node in nodes: + ret |= self._dimensional_equivalents[node] + + return ret + def convert(self, value, src, dst): """Convert value from some source to destination units. @@ -967,7 +1025,7 @@ class UnitRegistry(object): return self.Quantity(1) input_string = string_preprocessor(input_string) - gen = ptok(input_string) + gen = tokenizer(input_string) result = [] unknown = set() for toknum, tokval, _, _, _ in gen: @@ -1088,3 +1146,22 @@ def build_quantity_class(registry, force_ndarray=False): Quantity.force_ndarray = force_ndarray return Quantity + + +def build_measurement_class(registry, force_ndarray=False): + from .measurement import _Measurement, ufloat + + if ufloat is None: + class Measurement(object): + + def __init__(self, *args): + raise RuntimeError("Pint requires the 'uncertainties' package to create a Measurement object.") + + else: + class Measurement(_Measurement, registry.Quantity): + pass + + Measurement._REGISTRY = registry + Measurement.force_ndarray = force_ndarray + + return Measurement diff --git a/pint/util.py b/pint/util.py index 1fce174..f8ffd15 100644 --- a/pint/util.py +++ b/pint/util.py @@ -12,8 +12,6 @@ from __future__ import division, unicode_literals, print_function, absolute_import import re -import sys -import tokenize import operator from numbers import Number from fractions import Fraction @@ -22,20 +20,11 @@ import logging from token import STRING, NAME, OP from tokenize import untokenize -from .compat import NullHandler +from .compat import string_types, tokenizer, lru_cache, NullHandler logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) -if sys.version < '3': - from StringIO import StringIO - string_types = basestring - ptok = lambda input_string: tokenize.generate_tokens(StringIO(input_string).readline) -else: - from io import BytesIO - string_types = str - ptok = lambda input_string: tokenize.tokenize(BytesIO(input_string.encode('utf-8')).readline) - def matrix_to_string(matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x))): """Takes a 2D matrix (as nested list) and returns a string. @@ -123,21 +112,6 @@ def column_echelon_form(matrix, ntype=Fraction, transpose_result=False): return _transpose(M), _transpose(I), swapped -try: - import numpy as np - from numpy import ndarray - - HAS_NUMPY = True - NUMERIC_TYPES = (Number, ndarray) - -except ImportError: - - class ndarray(object): - pass - - HAS_NUMPY = False - NUMERIC_TYPES = (Number, ) - def pi_theorem(quantities, registry=None): """Builds dimensionless quantities using the Buckingham π theorem @@ -217,8 +191,8 @@ def solve_dependencies(dependencies): return r -def find_shortest_path(graph, start, end, path=[]): - path = path + [start] +def find_shortest_path(graph, start, end, path=None): + path = (path or []) + [start] if start == end: return path if not start in graph: @@ -233,6 +207,20 @@ def find_shortest_path(graph, start, end, path=[]): return shortest +def find_connected_nodes(graph, start, visited=None): + if not start in graph: + return None + + visited = (visited or set()) + visited.add(start) + + for node in graph[start]: + if node not in visited: + find_connected_nodes(graph, node, visited) + + return visited + + class ParserHelper(dict): """The ParserHelper stores in place the product of variables and their respective exponent and implements the corresponding operations. @@ -256,6 +244,7 @@ class ParserHelper(dict): return ret @classmethod + @lru_cache() def from_string(cls, input_string): """Parse linear expression mathematical units and return a quantity object. """ @@ -271,7 +260,7 @@ class ParserHelper(dict): else: reps = False - gen = ptok(input_string) + gen = tokenizer(input_string) result = [] for toknum, tokval, _, _, _ in gen: if toknum == NAME: @@ -371,13 +360,13 @@ class ParserHelper(dict): #: List of regex substitution pairs. -_subs_re = [(r"({0}) squared", r"\1**2"), # Handle square and cube +_subs_re = [(r"([\w\.\-\+\*\\\^])\s+", r"\1 "), # merge multiple spaces + (r"({0}) squared", r"\1**2"), # Handle square and cube (r"({0}) cubed", r"\1**3"), (r"cubic ({0})", r"\1**3"), (r"square ({0})", r"\1**2"), (r"sq ({0})", r"\1**2"), - (r"(\w)\s+(?=\w)", r"\1*"), # Handle space for multiplication - (r"([0-9])(?={0})(?!(?:[e|E][-+]?[0-9]+))", r"\1*") + (r"([\w\.\-])\s+(?=\w)", r"\1*"), # Handle space for multiplication ] #: Compiles the regex and replace {0} by a regex that matches an identifier. |