summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml21
-rw-r--r--bench.py107
-rw-r--r--bench_base.yaml42
-rw-r--r--bench_numpy.yaml23
-rw-r--r--docs/getting.rst2
-rw-r--r--pint/__init__.py2
-rw-r--r--pint/compat/__init__.py112
-rw-r--r--pint/compat/chainmap.py (renamed from pint/compat.py)35
-rw-r--r--pint/compat/lrucache.py177
-rw-r--r--pint/compat/nullhandler.py32
-rw-r--r--pint/compat/transformdict.py136
-rw-r--r--pint/context.py5
-rw-r--r--pint/measurement.py109
-rw-r--r--pint/quantity.py101
-rw-r--r--pint/testsuite/__init__.py30
-rw-r--r--pint/testsuite/test_contexts.py3
-rw-r--r--pint/testsuite/test_issues.py22
-rw-r--r--pint/testsuite/test_measurement.py53
-rw-r--r--pint/testsuite/test_numpy.py3
-rw-r--r--pint/testsuite/test_pitheorem.py2
-rw-r--r--pint/testsuite/test_quantity.py6
-rw-r--r--pint/testsuite/test_umath.py3
-rw-r--r--pint/testsuite/test_unit.py50
-rw-r--r--pint/testsuite/test_util.py36
-rw-r--r--pint/unit.py107
-rw-r--r--pint/util.py55
26 files changed, 983 insertions, 291 deletions
diff --git a/.travis.yml b/.travis.yml
index a50582d..9769ad7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -7,10 +7,14 @@ python:
- "3.3"
env:
- - NUMPY_VERSION=0
- - NUMPY_VERSION=1.6
- - NUMPY_VERSION=1.7
- - NUMPY_VERSION=1.8
+ - UNCERTAINTIES="N" NUMPY_VERSION=0
+ - UNCERTAINTIES="N" NUMPY_VERSION=1.6
+ - UNCERTAINTIES="N" NUMPY_VERSION=1.7
+ - UNCERTAINTIES="N" NUMPY_VERSION=1.8
+ - UNCERTAINTIES="Y" NUMPY_VERSION=0
+ - UNCERTAINTIES="Y" NUMPY_VERSION=1.6
+ - UNCERTAINTIES="Y" NUMPY_VERSION=1.7
+ - UNCERTAINTIES="Y" NUMPY_VERSION=1.8
branches:
only:
@@ -19,6 +23,7 @@ branches:
install:
- if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then pip install --use-mirrors unittest2; fi
+ - if [[ $UNCERTAINTIES == 'Y' ]]; then pip install --use-mirrors uncertainties; fi
- if [ $NUMPY_VERSION = '0' ]; then pip uninstall -y numpy || true; else pip install numpy==$NUMPY_VERSION; fi
- pip install . --use-mirrors
- pip list --pre
@@ -29,6 +34,10 @@ matrix:
# Don't run with these version combinations as these NumPy packages don't seem to be installable with Travis
exclude:
- python: "3.2"
- env: NUMPY_VERSION=1.6
+ env: UNCERTAINTIES="N" NUMPY_VERSION=1.6
+ - python: "3.2"
+ env: UNCERTAINTIES="Y" NUMPY_VERSION=1.6
+ - python: "3.3"
+ env: UNCERTAINTIES="N" NUMPY_VERSION=1.6
- python: "3.3"
- env: NUMPY_VERSION=1.6
+ env: UNCERTAINTIES="Y" NUMPY_VERSION=1.6
diff --git a/bench.py b/bench.py
new file mode 100644
index 0000000..b0e1efc
--- /dev/null
+++ b/bench.py
@@ -0,0 +1,107 @@
+
+
+from __future__ import division, unicode_literals, print_function, absolute_import
+
+import glob
+import copy
+from timeit import Timer
+
+import yaml
+
+
+def time_stmt(stmt='pass', setup='pass', number=0, repeat=3):
+ """Timer function with the same behaviour as running `python -m timeit `
+ in the command line.
+
+ :return: elapsed time in seconds or NaN if the command failed.
+ :rtype: float
+ """
+
+ t = Timer(stmt, setup)
+
+ if not number:
+ # determine number so that 0.2 <= total time < 2.0
+ for i in range(1, 10):
+ number = 10**i
+
+ try:
+ x = t.timeit(number)
+ except:
+ print(t.print_exc())
+ return float('NaN')
+
+ if x >= 0.2:
+ break
+
+ try:
+ r = t.repeat(repeat, number)
+ except:
+ print(t.print_exc())
+ return float('NaN')
+
+ best = min(r)
+
+ return best / number
+
+
+def build_task(task, name='', setup='', number=0, repeat=3):
+ nt = copy.copy(task)
+
+ nt['name'] = (name + ' ' + task.get('name', '')).strip()
+ nt['setup'] = (setup + '\n' + task.get('setup', '')).strip('\n')
+ nt['stmt'] = task.get('stmt', '')
+ nt['number'] = task.get('number', number)
+ nt['repeat'] = task.get('repeat', repeat)
+
+ return nt
+
+
+def time_task(name, stmt='pass', setup='pass', number=0, repeat=3, stmts='', base=''):
+
+ if base:
+ nvalue = time_stmt(stmt=base, setup=setup, number=number, repeat=repeat)
+ yield name + ' (base)', nvalue
+ suffix = ' (normalized)'
+ else:
+ nvalue = 1.
+ suffix = ''
+
+ if stmt:
+ value = time_stmt(stmt=stmt, setup=setup, number=number, repeat=repeat)
+ yield name, value / nvalue
+
+ for task in stmts:
+ new_task = build_task(task, name, setup, number, repeat)
+ for task_name, value in time_task(**new_task):
+ yield task_name + suffix, value / nvalue
+
+
+def time_file(filename, name='', setup='', number=0, repeat=3):
+ """Open a yaml benchmark file an time each statement,
+
+ yields a tuple with filename, task name, time in seconds.
+ """
+ with open(filename, 'r') as fp:
+ tasks = yaml.load(fp)
+
+ for task in tasks:
+ new_task = build_task(task, name, setup, number, repeat)
+ for task_name, value in time_task(**new_task):
+ yield task_name, value
+
+
+def main(filenames=None):
+ if not filenames:
+ filenames = glob.iglob('bench_*.yaml')
+ elif isinstance(filenames, basestring):
+ filenames = [filenames, ]
+
+ for filename in filenames:
+ print(filename)
+ print('-' * len(filename))
+ print()
+ for task_name, value in time_file(filename):
+ print('%.2e %s' % (value, task_name))
+ print()
+
+main()
diff --git a/bench_base.yaml b/bench_base.yaml
new file mode 100644
index 0000000..f767a04
--- /dev/null
+++ b/bench_base.yaml
@@ -0,0 +1,42 @@
+
+
+- name: importing
+ stmt: import pint
+
+- name: empty registry
+ setup: import pint
+ stmt: ureg = pint.UnitRegistry(None)
+
+- name: default registry
+ setup: import pint
+ stmt: ureg = pint.UnitRegistry()
+
+- name: finding meter
+ setup: |
+ import pint
+ ureg = pint.UnitRegistry()
+ stmts:
+ - name: (attr)
+ stmt: q = ureg.meter
+ - name: (item)
+ stmt: q = ureg['meter']
+
+- name: base units
+ setup: |
+ import pint
+ ureg = pint.UnitRegistry()
+ stmts:
+ - name: meter
+ stmt: ureg.get_base_units('meter')
+ - name: yard
+ stmt: ureg.get_base_units('yard')
+ - name: meter / second
+ stmt: ureg.get_base_units('meter / second')
+ - name: yard / minute
+ stmt: ureg.get_base_units('yard / minute')
+
+- name: build cache
+ setup: |
+ import pint
+ ureg = pint.UnitRegistry()
+ stmt: ureg._build_cache()
diff --git a/bench_numpy.yaml b/bench_numpy.yaml
new file mode 100644
index 0000000..8609bc0
--- /dev/null
+++ b/bench_numpy.yaml
@@ -0,0 +1,23 @@
+
+
+- name: NumPy
+ setup: |
+ import numpy as np
+ import pint
+ ureg = pint.UnitRegistry()
+ stmts:
+ - name: cosine
+ setup: |
+ d = np.arange(0, 90, 10)
+ r = np.deg2rad(d)
+ base: np.cos(r)
+ stmts:
+ - name: radian
+ setup: x = r * ureg.radian
+ stmt: np.cos(x)
+ - name: dimensionless
+ setup: x = r * ureg.dimensionless
+ stmt: np.cos(x)
+ - name: degree
+ setup: x = d * ureg.degree
+ stmt: np.cos(x)
diff --git a/docs/getting.rst b/docs/getting.rst
index 14f1d7f..a845d6b 100644
--- a/docs/getting.rst
+++ b/docs/getting.rst
@@ -3,7 +3,7 @@
Installation
============
-Pint has no dependencies except Python_ itself. In runs on Python 2.7 and 3.0+.
+Pint has no dependencies except Python_ itself. In runs on Python 2.6 and 3.0+.
You can install it using pip_::
diff --git a/pint/__init__.py b/pint/__init__.py
index 45de3b0..049a7d6 100644
--- a/pint/__init__.py
+++ b/pint/__init__.py
@@ -18,7 +18,7 @@ import pkg_resources
from .formatter import formatter
from .unit import UnitRegistry, DimensionalityError, UndefinedUnitError
from .util import pi_theorem, logger
-from .measurement import Measurement
+
from .context import Context
_DEFAULT_REGISTRY = UnitRegistry()
diff --git a/pint/compat/__init__.py b/pint/compat/__init__.py
new file mode 100644
index 0000000..e4221c4
--- /dev/null
+++ b/pint/compat/__init__.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+"""
+ pint.compat
+ ~~~~~~~~~~~
+
+ Compatibility layer.
+
+ :copyright: 2013 by Pint Authors, see AUTHORS for more details.
+ :license: BSD, see LICENSE for more details.
+"""
+
+from __future__ import division, unicode_literals, print_function, absolute_import
+
+import sys
+import tokenize
+
+from numbers import Number
+from decimal import Decimal
+
+
+PYTHON3 = sys.version >= '3'
+
+if PYTHON3:
+ from io import BytesIO
+ string_types = str
+ tokenizer = lambda input_string: tokenize.tokenize(BytesIO(input_string.encode('utf-8')).readline)
+
+ def u(x):
+ return x
+else:
+ from StringIO import StringIO
+ string_types = basestring
+ tokenizer = lambda input_string: tokenize.generate_tokens(StringIO(input_string).readline)
+
+ import codecs
+ string_types = basestring
+
+ def u(x):
+ return codecs.unicode_escape_decode(x)[0]
+
+
+if sys.version_info < (2, 7):
+ import unittest2 as unittest
+else:
+ import unittest
+
+
+try:
+ from collections import Chainmap
+except ImportError:
+ from .chainmap import ChainMap
+
+try:
+ from collections import TransformDict
+except ImportError:
+ from .transformdict import TransformDict
+
+try:
+ from functools import lru_cache
+except ImportError:
+ from .lrucache import lru_cache
+
+try:
+ from logging import NullHandler
+except ImportError:
+ from .nullhandler import NullHandler
+
+try:
+ import numpy as np
+ from numpy import ndarray
+
+ HAS_NUMPY = True
+ NUMPY_VER = np.__version__
+ NUMERIC_TYPES = (Number, Decimal, ndarray, np.number)
+
+ def _to_magnitude(value, force_ndarray=False):
+ if isinstance(value, (dict, bool)) or value is None:
+ raise TypeError('Invalid magnitude for Quantity: {0!r}'.format(value))
+ elif isinstance(value, string_types) and value == '':
+ raise ValueError('Quantity magnitude cannot be an empty string.')
+ elif isinstance(value, (list, tuple)):
+ return np.asarray(value)
+ if force_ndarray:
+ return np.asarray(value)
+ return value
+
+except ImportError:
+
+ np = None
+
+ class ndarray(object):
+ pass
+
+ HAS_NUMPY = False
+ NUMPY_VER = 0
+ NUMERIC_TYPES = (Number, Decimal)
+
+ def _to_magnitude(value, force_ndarray=False):
+ if isinstance(value, (dict, bool)) or value is None:
+ raise TypeError('Invalid magnitude for Quantity: {0!r}'.format(value))
+ elif isinstance(value, string_types) and value == '':
+ raise ValueError('Quantity magnitude cannot be an empty string.')
+ elif isinstance(value, (list, tuple)):
+ raise TypeError('lists and tuples are valid magnitudes for '
+ 'Quantity only when NumPy is present.')
+ return value
+
+try:
+ from uncertainties import ufloat
+except ImportError:
+ ufloat = None
+
diff --git a/pint/compat.py b/pint/compat/chainmap.py
index 65808f1..f4c9a4e 100644
--- a/pint/compat.py
+++ b/pint/compat/chainmap.py
@@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
"""
- pint.compat
- ~~~~~~~~~~~
+ pint.compat.chainmap
+ ~~~~~~~~~~~~~~~~~~~~
- Compatibility layer.
+ Taken from the Python 3.3 source code.
- :copyright: 2013 by Pint Authors, see AUTHORS for more details.
- :license: BSD, see LICENSE for more details.
+ :copyright: 2013, PSF
+ :license: PSF License
"""
from __future__ import division, unicode_literals, print_function, absolute_import
import sys
-import logging
from collections import MutableMapping
if sys.version_info < (3, 0):
@@ -152,27 +151,3 @@ class ChainMap(MutableMapping):
def clear(self):
'Clear maps[0], leaving maps[1:] intact.'
self.maps[0].clear()
-
-
-if hasattr(logging, "NullHandler"):
- NullHandler = logging.NullHandler
-else:
- class NullHandler(logging.Handler):
- """
- This handler does nothing. It's intended to be used to avoid the
- "No handlers could be found for logger XXX" one-off warning. This is
- important for library code, which may contain code to log events. If a user
- of the library does not configure logging, the one-off warning might be
- produced; to avoid this, the library developer simply needs to instantiate
- a NullHandler and add it to the top-level logger of the library module or
- package.
- """
- def handle(self, record):
- pass
-
- def emit(self, record):
- pass
-
- def createLock(self):
- self.lock = None
-
diff --git a/pint/compat/lrucache.py b/pint/compat/lrucache.py
new file mode 100644
index 0000000..868b598
--- /dev/null
+++ b/pint/compat/lrucache.py
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+"""
+ pint.compat.lrucache
+ ~~~~~~~~~~~~~~~~~~~~
+
+ LRU (least recently used) cache backport.
+
+ From https://code.activestate.com/recipes/578078-py26-and-py30-backport-of-python-33s-lru-cache/
+
+ :copyright: 2004, Raymond Hettinger,
+ :license: MIT License
+"""
+
+from collections import namedtuple
+from functools import update_wrapper
+from threading import RLock
+
+_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
+
+class _HashedSeq(list):
+ __slots__ = 'hashvalue'
+
+ def __init__(self, tup, hash=hash):
+ self[:] = tup
+ self.hashvalue = hash(tup)
+
+ def __hash__(self):
+ return self.hashvalue
+
+def _make_key(args, kwds, typed,
+ kwd_mark = (object(),),
+ fasttypes = set((int, str, frozenset, type(None))),
+ sorted=sorted, tuple=tuple, type=type, len=len):
+ 'Make a cache key from optionally typed positional and keyword arguments'
+ key = args
+ if kwds:
+ sorted_items = sorted(kwds.items())
+ key += kwd_mark
+ for item in sorted_items:
+ key += item
+ if typed:
+ key += tuple(type(v) for v in args)
+ if kwds:
+ key += tuple(type(v) for k, v in sorted_items)
+ elif len(key) == 1 and type(key[0]) in fasttypes:
+ return key[0]
+ return _HashedSeq(key)
+
+def lru_cache(maxsize=100, typed=False):
+ """Least-recently-used cache decorator.
+
+ If *maxsize* is set to None, the LRU features are disabled and the cache
+ can grow without bound.
+
+ If *typed* is True, arguments of different types will be cached separately.
+ For example, f(3.0) and f(3) will be treated as distinct calls with
+ distinct results.
+
+ Arguments to the cached function must be hashable.
+
+ View the cache statistics named tuple (hits, misses, maxsize, currsize) with
+ f.cache_info(). Clear the cache and statistics with f.cache_clear().
+ Access the underlying function with f.__wrapped__.
+
+ See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used
+
+ """
+
+ # Users should only access the lru_cache through its public API:
+ # cache_info, cache_clear, and f.__wrapped__
+ # The internals of the lru_cache are encapsulated for thread safety and
+ # to allow the implementation to change (including a possible C version).
+
+ def decorating_function(user_function):
+
+ cache = dict()
+ stats = [0, 0] # make statistics updateable non-locally
+ HITS, MISSES = 0, 1 # names for the stats fields
+ make_key = _make_key
+ cache_get = cache.get # bound method to lookup key or return None
+ _len = len # localize the global len() function
+ lock = RLock() # because linkedlist updates aren't threadsafe
+ root = [] # root of the circular doubly linked list
+ root[:] = [root, root, None, None] # initialize by pointing to self
+ nonlocal_root = [root] # make updateable non-locally
+ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
+
+ if maxsize == 0:
+
+ def wrapper(*args, **kwds):
+ # no caching, just do a statistics update after a successful call
+ result = user_function(*args, **kwds)
+ stats[MISSES] += 1
+ return result
+
+ elif maxsize is None:
+
+ def wrapper(*args, **kwds):
+ # simple caching without ordering or size limit
+ key = make_key(args, kwds, typed)
+ result = cache_get(key, root) # root used here as a unique not-found sentinel
+ if result is not root:
+ stats[HITS] += 1
+ return result
+ result = user_function(*args, **kwds)
+ cache[key] = result
+ stats[MISSES] += 1
+ return result
+
+ else:
+
+ def wrapper(*args, **kwds):
+ # size limited caching that tracks accesses by recency
+ key = make_key(args, kwds, typed) if kwds or typed else args
+ with lock:
+ link = cache_get(key)
+ if link is not None:
+ # record recent use of the key by moving it to the front of the list
+ root, = nonlocal_root
+ link_prev, link_next, key, result = link
+ link_prev[NEXT] = link_next
+ link_next[PREV] = link_prev
+ last = root[PREV]
+ last[NEXT] = root[PREV] = link
+ link[PREV] = last
+ link[NEXT] = root
+ stats[HITS] += 1
+ return result
+ result = user_function(*args, **kwds)
+ with lock:
+ root, = nonlocal_root
+ if key in cache:
+ # getting here means that this same key was added to the
+ # cache while the lock was released. since the link
+ # update is already done, we need only return the
+ # computed result and update the count of misses.
+ pass
+ elif _len(cache) >= maxsize:
+ # use the old root to store the new key and result
+ oldroot = root
+ oldroot[KEY] = key
+ oldroot[RESULT] = result
+ # empty the oldest link and make it the new root
+ root = nonlocal_root[0] = oldroot[NEXT]
+ oldkey = root[KEY]
+ oldvalue = root[RESULT]
+ root[KEY] = root[RESULT] = None
+ # now update the cache dictionary for the new links
+ del cache[oldkey]
+ cache[key] = oldroot
+ else:
+ # put result in a new link at the front of the list
+ last = root[PREV]
+ link = [last, root, key, result]
+ last[NEXT] = root[PREV] = cache[key] = link
+ stats[MISSES] += 1
+ return result
+
+ def cache_info():
+ """Report cache statistics"""
+ with lock:
+ return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache))
+
+ def cache_clear():
+ """Clear the cache and cache statistics"""
+ with lock:
+ cache.clear()
+ root = nonlocal_root[0]
+ root[:] = [root, root, None, None]
+ stats[:] = [0, 0]
+
+ wrapper.__wrapped__ = user_function
+ wrapper.cache_info = cache_info
+ wrapper.cache_clear = cache_clear
+ return update_wrapper(wrapper, user_function)
+
+ return decorating_function
diff --git a/pint/compat/nullhandler.py b/pint/compat/nullhandler.py
new file mode 100644
index 0000000..288cbb3
--- /dev/null
+++ b/pint/compat/nullhandler.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""
+ pint.compat.nullhandler
+ ~~~~~~~~~~~~~~~~~~~~~~~
+
+ Taken from the Python 2.7 source code.
+
+ :copyright: 2013, PSF
+ :license: PSF License
+"""
+
+
+import logging
+
+class NullHandler(logging.Handler):
+ """
+ This handler does nothing. It's intended to be used to avoid the
+ "No handlers could be found for logger XXX" one-off warning. This is
+ important for library code, which may contain code to log events. If a user
+ of the library does not configure logging, the one-off warning might be
+ produced; to avoid this, the library developer simply needs to instantiate
+ a NullHandler and add it to the top-level logger of the library module or
+ package.
+ """
+ def handle(self, record):
+ pass
+
+ def emit(self, record):
+ pass
+
+ def createLock(self):
+ self.lock = None
diff --git a/pint/compat/transformdict.py b/pint/compat/transformdict.py
new file mode 100644
index 0000000..c01ea30
--- /dev/null
+++ b/pint/compat/transformdict.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+"""
+ pint.compat.transformdict
+ ~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ Taken from the Python 3.4 source code.
+
+ :copyright: 2013, PSF
+ :license: PSF License
+"""
+
+from collections import MutableMapping
+
+_sentinel = object()
+
+class TransformDict(MutableMapping):
+ '''Dictionary that calls a transformation function when looking
+ up keys, but preserves the original keys.
+
+ >>> d = TransformDict(str.lower)
+ >>> d['Foo'] = 5
+ >>> d['foo'] == d['FOO'] == d['Foo'] == 5
+ True
+ >>> set(d.keys())
+ {'Foo'}
+ '''
+
+ __slots__ = ('_transform', '_original', '_data')
+
+ def __init__(self, transform, init_dict=None, **kwargs):
+ '''Create a new TransformDict with the given *transform* function.
+ *init_dict* and *kwargs* are optional initializers, as in the
+ dict constructor.
+ '''
+ if not callable(transform):
+ raise TypeError("expected a callable, got %r" % transform.__class__)
+ self._transform = transform
+ # transformed => original
+ self._original = {}
+ self._data = {}
+ if init_dict:
+ self.update(init_dict)
+ if kwargs:
+ self.update(kwargs)
+
+ def getitem(self, key):
+ 'D.getitem(key) -> (stored key, value)'
+ transformed = self._transform(key)
+ original = self._original[transformed]
+ value = self._data[transformed]
+ return original, value
+
+ @property
+ def transform_func(self):
+ "This TransformDict's transformation function"
+ return self._transform
+
+ # Minimum set of methods required for MutableMapping
+
+ def __len__(self):
+ return len(self._data)
+
+ def __iter__(self):
+ return iter(self._original.values())
+
+ def __getitem__(self, key):
+ return self._data[self._transform(key)]
+
+ def __setitem__(self, key, value):
+ transformed = self._transform(key)
+ self._data[transformed] = value
+ self._original.setdefault(transformed, key)
+
+ def __delitem__(self, key):
+ transformed = self._transform(key)
+ del self._data[transformed]
+ del self._original[transformed]
+
+ # Methods overriden to mitigate the performance overhead.
+
+ def clear(self):
+ 'D.clear() -> None. Remove all items from D.'
+ self._data.clear()
+ self._original.clear()
+
+ def __contains__(self, key):
+ return self._transform(key) in self._data
+
+ def get(self, key, default=None):
+ 'D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.'
+ return self._data.get(self._transform(key), default)
+
+ def pop(self, key, default=_sentinel):
+ '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
+ If key is not found, d is returned if given, otherwise KeyError is raised.
+ '''
+ transformed = self._transform(key)
+ if default is _sentinel:
+ del self._original[transformed]
+ return self._data.pop(transformed)
+ else:
+ self._original.pop(transformed, None)
+ return self._data.pop(transformed, default)
+
+ def popitem(self):
+ '''D.popitem() -> (k, v), remove and return some (key, value) pair
+ as a 2-tuple; but raise KeyError if D is empty.
+ '''
+ transformed, value = self._data.popitem()
+ return self._original.pop(transformed), value
+
+ # Other methods
+
+ def copy(self):
+ 'D.copy() -> a shallow copy of D'
+ other = self.__class__(self._transform)
+ other._original = self._original.copy()
+ other._data = self._data.copy()
+ return other
+
+ __copy__ = copy
+
+ def __getstate__(self):
+ return (self._transform, self._data, self._original)
+
+ def __setstate__(self, state):
+ self._transform, self._data, self._original = state
+
+ def __repr__(self):
+ try:
+ equiv = dict(self)
+ except TypeError:
+ # Some keys are unhashable, fall back on .items()
+ equiv = list(self.items())
+ return '%s(%r, %s)' % (self.__class__.__name__,
+ self._transform, repr(equiv))
diff --git a/pint/context.py b/pint/context.py
index a59664e..947d643 100644
--- a/pint/context.py
+++ b/pint/context.py
@@ -15,8 +15,9 @@ from __future__ import division, unicode_literals, print_function, absolute_impo
import re
from collections import defaultdict
import weakref
-from pint.compat import ChainMap
-from pint.util import ParserHelper, string_types
+
+from .compat import ChainMap
+from .util import ParserHelper, string_types
#: Regex to match the header parts of a context.
_header_re = re.compile('@context\s*(?P<defaults>\(.*\))?\s+(?P<name>\w+)\s*(=(?P<aliases>.*))*')
diff --git a/pint/measurement.py b/pint/measurement.py
index 346ce0c..e3d5f4c 100644
--- a/pint/measurement.py
+++ b/pint/measurement.py
@@ -9,116 +9,51 @@
from __future__ import division, unicode_literals, print_function, absolute_import
+from .compat import ufloat
-import operator
+MISSING = object()
-
-class Measurement(object):
+class _Measurement(object):
"""Implements a class to describe a quantity with uncertainty.
:param value: The most likely value of the measurement.
:type value: Quantity or Number
:param error: The error or uncertainty of the measurement.
- :type value: Quantity or Number
+ :type error: Quantity or Number
"""
- def __init__(self, value, error):
- if not (value/error).unitless:
- raise ValueError('{0} and {1} have incompatible units'.format(value, error))
+ def __new__(cls, value, error, units=MISSING):
+ if units is MISSING:
+ try:
+ value, units = value.magnitude, value.units
+ except AttributeError:
+ try:
+ value, error, units = value.nominal_value, value.std_dev, error
+ except AttributeError:
+ units = ''
try:
- emag = error.magnitude
+ error = error.to(units).magnitude
except AttributeError:
- emag = error
+ pass
- if emag < 0:
- raise ValueError('The magnitude of the error cannot be negative'.format(value, error))
+ inst = super(_Measurement, cls).__new__(cls, ufloat(value, error), units)
- self._value = value
- self._error = error
+ if error < 0:
+ raise ValueError('The magnitude of the error cannot be negative'.format(value, error))
+ return inst
@property
def value(self):
- return self._value
+ return self._REGISTRY.Quantity(self.magnitude.nominal_value, self.units)
@property
def error(self):
- return self._error
+ return self._REGISTRY.Quantity(self.magnitude.std_dev, self.units)
@property
def rel(self):
- return float(abs(self._error / self._value))
-
- def _add_sub(self, other, operator):
- result = self.value + other.value
- if isinstance(other, self.__class__):
- error = (self.error ** 2.0 + other.error ** 2.0) ** (1/2)
- else:
- error = self.error
- return result.plus_minus(error)
-
- def __add__(self, other):
- return self._add_sub(other, operator.add)
-
- __radd__ = __add__
-
- def __sub__(self, other):
- return self._add_sub(other, operator.sub)
-
- __rsub__ = __sub__
-
- def _mul_div(self, other, operator):
- if isinstance(other, self.__class__):
- result = operator(self.value, other.value)
- return result.plus_minus((self.rel ** 2.0 + other.rel ** 2.0) ** (1/2), relative=True)
- else:
- result = operator(self.value, other)
- return result.plus_minus(abs(operator(self.error, other)))
-
- def __mul__(self, other):
- return self._mul_div(other, operator.mul)
-
- __rmul__ = __mul__
-
- def __truediv__(self, other):
- return self._mul_div(other, operator.truediv)
-
- def __floordiv__(self, other):
- return self._mul_div(other, operator.floordiv)
-
- __div__ = __floordiv__
-
- def __str__(self):
- return '{}'.format(self)
+ return float(abs(self.magnitude.std_dev / self.magnitude.nominal_value))
def __repr__(self):
return "<Measurement({0:!r}, {1:!r})>".format(self._value, self._error)
- def __format__(self, spec):
- if '!' in spec:
- fmt, conv = spec.split('!')
- conv = '!' + conv
- else:
- fmt, conv = spec, ''
-
- left, right = '(', ')'
- if '!l' == conv:
- pm = r'\pm'
- left = r'\left' + left
- right = r'\right' + right
- elif '!p' == conv:
- pm = '±'
- else:
- pm = '+/-'
-
- if hasattr(self.value, 'units'):
- vmag = format(self.value.magnitude, fmt)
- if self.value.units != self.error.units:
- emag = self.error.to(self.value.units).magnitude
- else:
- emag = self.error.magnitude
- emag = format(emag, fmt)
- units = ' ' + format(self.value.units, conv)
- else:
- vmag, emag, units = self.value, self.error, ''
-
- return left + vmag + ' ' + pm + ' ' + emag + right + units if units else ''
diff --git a/pint/quantity.py b/pint/quantity.py
index 853195b..924f74b 100644
--- a/pint/quantity.py
+++ b/pint/quantity.py
@@ -16,33 +16,7 @@ from collections import Iterable
from .formatter import remove_custom_flags
from .unit import DimensionalityError, UnitsContainer, UnitDefinition, UndefinedUnitError
-from .measurement import Measurement
-from .util import string_types, NUMERIC_TYPES, ndarray
-
-try:
- import numpy as np
-
- def _to_magnitude(value, force_ndarray=False):
- if isinstance(value, (dict, bool)) or value is None:
- raise ValueError('Invalid magnitude for Quantity: {!r}'.format(value))
- elif isinstance(value, string_types) and value == '':
- raise ValueError('Quantity magnitude cannot be an empty string.')
- elif isinstance(value, (list, tuple)):
- return np.asarray(value)
- if force_ndarray:
- return np.asarray(value)
- return value
-
-except ImportError:
- def _to_magnitude(value, force_ndarray=False):
- if isinstance(value, (dict, bool)) or value is None:
- raise ValueError('Invalid magnitude for Quantity: {!r}'.format(value))
- elif isinstance(value, string_types) and value == '':
- raise ValueError('Quantity magnitude cannot be an empty string.')
- elif isinstance(value, (list, tuple)):
- raise ValueError('lists and tuples are valid magnitudes for '
- 'Quantity only when NumPy is present.')
- return value
+from .compat import string_types, ndarray, np, _to_magnitude
def _eq(first, second, check_all):
@@ -84,7 +58,7 @@ def _check(q1, other):
raise ValueError('Cannot operate between quantities of different registries')
-def _has_multiplicative_units(q):
+def _only_multiplicative_units(q):
"""Check if the quantity has non-multiplicative units.
"""
@@ -118,7 +92,7 @@ class _Quantity(object):
if units is None:
if isinstance(value, string_types):
if value == '':
- raise ValueError('Quantity magnitude cannot be an empty string.')
+ raise ValueError('Expression to parse as Quantity cannot be an empty string.')
inst = cls._REGISTRY.parse_expression(value)
return cls.__new__(cls, inst)
elif isinstance(value, cls):
@@ -142,11 +116,18 @@ class _Quantity(object):
raise TypeError('units must be of type str, Quantity or '
'UnitsContainer; not {0}.'.format(type(units)))
+ inst.__used = False
inst.__handling = None
return inst
+ @property
+ def debug_used(self):
+ return self.__used
+
def __copy__(self):
- return self.__class__(copy.copy(self._magnitude), copy.copy(self._units))
+ ret = self.__class__(copy.copy(self._magnitude), copy.copy(self._units))
+ ret.__used = self.__used
+ return ret
def __str__(self):
return '{0} {1}'.format(self._magnitude, self._units)
@@ -213,6 +194,13 @@ class _Quantity(object):
return self._dimensionality
+ def compatible_units(self, *contexts):
+ if contexts:
+ with self._REGISTRY.context(*contexts):
+ return self._REGISTRY.get_compatible_units(self._units)
+
+ return self._REGISTRY.get_compatible_units(self._units)
+
def _convert_magnitude(self, other, *contexts, **ctx_kwargs):
if contexts:
with self._REGISTRY.context(*contexts, **ctx_kwargs):
@@ -279,6 +267,13 @@ class _Quantity(object):
raise DimensionalityError(self.units, 'dimensionless')
def _iadd_sub(self, other, op):
+ """Perform addition or subtraction operation in-place and return the result.
+
+ :param other: object to be added to / subtracted from self
+ :type other: Quantity or any type accepted by :func:`_to_magnitude`
+ :param op: operator function (e.g. operator.add, operator.isub)
+ :type op: function
+ """
if _check(self, other):
if not self.dimensionality == other.dimensionality:
raise DimensionalityError(self.units, other.units,
@@ -288,15 +283,26 @@ class _Quantity(object):
else:
self._magnitude = op(self._magnitude, other.to(self)._magnitude)
else:
+ try:
+ other_magnitude = _to_magnitude(other, self.force_ndarray)
+ except TypeError:
+ return NotImplemented
if self.dimensionless:
self.ito(UnitsContainer())
- self._magnitude = op(self._magnitude, _to_magnitude(other, self.force_ndarray))
+ self._magnitude = op(self._magnitude, other_magnitude)
else:
raise DimensionalityError(self.units, 'dimensionless')
return self
def _add_sub(self, other, op):
+ """Perform addition or subtraction operation and return the result.
+
+ :param other: object to be added to / subtracted from self
+ :type other: Quantity or any type accepted by :func:`_to_magnitude`
+ :param op: operator function (e.g. operator.add, operator.isub)
+ :type op: function
+ """
if _check(self, other):
if not self.dimensionality == other.dimensionality:
raise DimensionalityError(self.units, other.units,
@@ -337,18 +343,24 @@ class _Quantity(object):
def _imul_div(self, other, magnitude_op, units_op=None):
"""Perform multiplication or division operation in-place and return the result.
- Arguments:
- other -- object to be multiplied/divided with self
- magnitude_op -- operator function to perform on the magnitudes (e.g. operator.mul)
- units_op -- operator function to perform on the units; if None, magnitude_op is used
-
+ :param other: object to be multiplied/divided with self
+ :type other: Quantity or any type accepted by :func:`_to_magnitude`
+ :param magnitude_op: operator function to perform on the magnitudes (e.g. operator.mul)
+ :type magnitude_op: function
+ :param units_op: operator function to perform on the units; if None, *magnitude_op* is used
+ :type units_op: function or None
"""
if units_op is None:
units_op = magnitude_op
- if _check(self, other):
- if not _has_multiplicative_units(self):
+
+ if self.__used:
+ if not _only_multiplicative_units(self):
self.ito_base_units()
- if not _has_multiplicative_units(other):
+ else:
+ self.__used = True
+
+ if _check(self, other):
+ if not _only_multiplicative_units(other):
other = other.to_base_units()
self._magnitude = magnitude_op(self._magnitude, other._magnitude)
self._units = units_op(self._units, other._units)
@@ -410,7 +422,7 @@ class _Quantity(object):
except TypeError:
return NotImplemented
else:
- if not _has_multiplicative_units(self):
+ if not _only_multiplicative_units(self):
self.ito_base_units()
self._magnitude **= _to_magnitude(other, self.force_ndarray)
self._units **= other
@@ -785,11 +797,10 @@ class _Quantity(object):
def plus_minus(self, error, relative=False):
if isinstance(error, self.__class__):
if relative:
- raise ValueError('{0} is not a valid relative error.'.format(error))
+ raise ValueError('{} is not a valid relative error.'.format(error))
+ error = error.to(self.units).magnitude
else:
if relative:
- error = error * abs(self)
- else:
- error = self.__class__(error, self.units)
+ error = error * abs(self.magnitude)
- return Measurement(copy.copy(self), error)
+ return self._REGISTRY.Measurement(copy.copy(self.magnitude), error, self.units)
diff --git a/pint/testsuite/__init__.py b/pint/testsuite/__init__.py
index f8e94aa..528ead7 100644
--- a/pint/testsuite/__init__.py
+++ b/pint/testsuite/__init__.py
@@ -3,37 +3,9 @@
from __future__ import division, unicode_literals, print_function, absolute_import
import os
-import sys
import logging
-if sys.version_info < (2, 7):
- import unittest2 as unittest
-else:
- import unittest
-
-try:
- import numpy as np
- HAS_NUMPY = True
- ndarray = np.ndarray
- NUMPY_VER = np.__version__
-except ImportError:
- np = None
- HAS_NUMPY = False
- NUMPY_VER = 0
- class ndarray(object):
- pass
-
-PYTHON3 = sys.version >= '3'
-
-if PYTHON3:
- string_types = str
- def u(x):
- return x
-else:
- import codecs
- string_types = basestring
- def u(x):
- return codecs.unicode_escape_decode(x)[0]
+from pint.compat import ndarray, unittest
from pint import logger, UnitRegistry
diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py
index 4f93a9c..598f138 100644
--- a/pint/testsuite/test_contexts.py
+++ b/pint/testsuite/test_contexts.py
@@ -8,7 +8,8 @@ from collections import defaultdict
from pint import UnitRegistry
from pint.context import Context, _freeze
from pint.unit import UnitsContainer
-from pint.testsuite import TestCase, unittest
+from pint.testsuite import TestCase
+from pint.compat import unittest
from pint import logger
diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py
index 0eaecaa..57d675b 100644
--- a/pint/testsuite/test_issues.py
+++ b/pint/testsuite/test_issues.py
@@ -6,7 +6,9 @@ from pint import UnitRegistry
from pint.unit import UnitsContainer
from pint.util import ParserHelper
-from pint.testsuite import HAS_NUMPY, np, TestCase, ndarray, NUMPY_VER, unittest
+from pint.compat import HAS_NUMPY, np, unittest
+from pint.testsuite import TestCase
+
class TestIssues(TestCase):
@@ -62,17 +64,19 @@ class TestIssues(TestCase):
def test_issue61(self):
ureg = UnitRegistry()
Q_ = ureg.Quantity
- for value in ({}, {'a': 3}, '', None, True, False):
- self.assertRaises(ValueError, Q_, value)
- self.assertRaises(ValueError, Q_, value, 'meter')
+ for value in ({}, {'a': 3}, None):
+ self.assertRaises(TypeError, Q_, value)
+ self.assertRaises(TypeError, Q_, value, 'meter')
+ self.assertRaises(ValueError, Q_, '', 'meter')
+ self.assertRaises(ValueError, Q_, '')
@unittest.skipIf(HAS_NUMPY, 'Numpy present')
def test_issue61_notNP(self):
ureg = UnitRegistry()
Q_ = ureg.Quantity
for value in ([1, 2, 3], (1, 2, 3)):
- self.assertRaises(ValueError, Q_, value)
- self.assertRaises(ValueError, Q_, value, 'meter')
+ self.assertRaises(TypeError, Q_, value)
+ self.assertRaises(TypeError, Q_, value, 'meter')
def test_issue66(self):
ureg = UnitRegistry()
@@ -163,6 +167,12 @@ class TestIssues(TestCase):
self.assertAlmostEqual(v1.to_base_units(), v2)
self.assertAlmostEqual(v1.to_base_units(), v2.to_base_units())
+ def test_issue86c(self):
+ ureg = self.ureg
+ T = ureg.degC
+ T = 100. * T
+ self.assertAlmostEqual(ureg.k*2*T, ureg.k*(2*T))
+
def test_issue93(self):
ureg = UnitRegistry()
self.assertIsInstance(ureg.meter.magnitude, int)
diff --git a/pint/testsuite/test_measurement.py b/pint/testsuite/test_measurement.py
index c4cc14b..731ded8 100644
--- a/pint/testsuite/test_measurement.py
+++ b/pint/testsuite/test_measurement.py
@@ -2,21 +2,28 @@
from __future__ import division, unicode_literals, print_function, absolute_import
-from pint import Measurement
-from pint.testsuite import TestCase
-
+from pint.compat import ufloat
+from pint.testsuite import TestCase, unittest
+@unittest.skipIf(ufloat is None, 'Uncertainties not installed.')
class TestMeasurement(TestCase):
FORCE_NDARRAY = False
+ def test_simple(self):
+ M_ = self.ureg.Measurement
+ M_(4.0, 0.1, 's')
+
def test_build(self):
+ M_ = self.ureg.Measurement
v, u = self.Q_(4.0, 's'), self.Q_(.1, 's')
-
- ms = (Measurement(v, u),
+ M_(v.magnitude, u.magnitude, 's')
+ ms = (M_(v.magnitude, u.magnitude, 's'),
+ M_(v, u.magnitude),
+ M_(v, u),
v.plus_minus(.1),
v.plus_minus(0.025, True),
- v.plus_minus(u))
+ v.plus_minus(u),)
for m in ms:
self.assertEqual(m.value, v)
@@ -25,7 +32,7 @@ class TestMeasurement(TestCase):
def _test_format(self):
v, u = self.Q_(4.0, 's'), self.Q_(.1, 's')
- m = Measurement(v, u)
+ m = self.ureg.Measurement(v, u)
print(str(m))
print(repr(m))
print('{:!s}'.format(m))
@@ -39,8 +46,8 @@ class TestMeasurement(TestCase):
v, u = self.Q_(1.0, 's'), self.Q_(.1, 's')
o = self.Q_(.1, 'm')
- self.assertRaises(ValueError, Measurement, u, 1)
- self.assertRaises(ValueError, Measurement, u, o)
+ M_ = self.ureg.Measurement
+ self.assertRaises(ValueError, M_, v, o)
self.assertRaises(ValueError, v.plus_minus, o)
self.assertRaises(ValueError, v.plus_minus, u, True)
@@ -56,18 +63,26 @@ class TestMeasurement(TestCase):
for factor, m in zip((3, -3, 3, -3), (m1, m3, m1, m3)):
r = factor * m
- self.assertAlmostEqual(r.value, factor * m.value)
- self.assertAlmostEqual(r.error ** 2.0, (factor * m.error) **2.0)
+ self.assertAlmostEqual(r.value.magnitude, factor * m.value.magnitude)
+ self.assertAlmostEqual(r.error.magnitude, abs(factor * m.error.magnitude))
+ self.assertEqual(r.value.units, m.value.units)
for ml, mr in zip((m1, m1, m1, m3), (m1, m2, m3, m3)):
r = ml + mr
- self.assertAlmostEqual(r.value, ml.value + mr.value)
- self.assertAlmostEqual(r.error ** 2.0, ml.error **2.0 + mr.error ** 2.0)
+ self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude + mr.value.magnitude)
+ self.assertAlmostEqual(r.error.magnitude,
+ ml.error.magnitude + mr.error.magnitude if ml is mr else
+ (ml.error.magnitude ** 2 + mr.error.magnitude ** 2) ** .5)
+ self.assertEqual(r.value.units, ml.value.units)
for ml, mr in zip((m1, m1, m1, m3), (m1, m2, m3, m3)):
r = ml - mr
- self.assertAlmostEqual(r.value, ml.value + mr.value)
- self.assertAlmostEqual(r.error ** 2.0, ml.error **2.0 + mr.error ** 2.0)
+ print(ml, mr, ml is mr, r)
+ self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude - mr.value.magnitude)
+ self.assertAlmostEqual(r.error.magnitude,
+ 0 if ml is mr else
+ (ml.error.magnitude ** 2 + mr.error.magnitude ** 2) ** .5)
+ self.assertEqual(r.value.units, ml.value.units)
def test_propagate_product(self):
@@ -84,10 +99,10 @@ class TestMeasurement(TestCase):
for ml, mr in zip((m1, m1, m1, m3, m4), (m1, m2, m3, m3, m5)):
r = ml * mr
- self.assertAlmostEqual(r.value, ml.value * mr.value)
- self.assertAlmostEqual(r.rel ** 2.0, ml.rel ** 2.0 + mr.rel ** 2.0)
+ self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude * mr.value.magnitude)
+ self.assertEqual(r.value.units, ml.value.units * mr.value.units)
for ml, mr in zip((m1, m1, m1, m3, m4), (m1, m2, m3, m3, m5)):
r = ml / mr
- self.assertAlmostEqual(r.value, ml.value / mr.value)
- self.assertAlmostEqual(r.rel ** 2.0, ml.rel ** 2.0 + mr.rel ** 2.0)
+ self.assertAlmostEqual(r.value.magnitude, ml.value.magnitude / mr.value.magnitude)
+ self.assertEqual(r.value.units, ml.value.units / mr.value.units)
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 1bf45ba..5063e1c 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -2,7 +2,8 @@
from __future__ import division, unicode_literals, print_function, absolute_import
-from pint.testsuite import TestCase, HAS_NUMPY, np, unittest
+from pint.compat import HAS_NUMPY, np, unittest
+from pint.testsuite import TestCase
@unittest.skipUnless(HAS_NUMPY, 'Numpy not present')
class TestNumpyMethods(TestCase):
diff --git a/pint/testsuite/test_pitheorem.py b/pint/testsuite/test_pitheorem.py
index 88f7fbb..3f5a2a1 100644
--- a/pint/testsuite/test_pitheorem.py
+++ b/pint/testsuite/test_pitheorem.py
@@ -6,7 +6,7 @@ import itertools
from pint import pi_theorem
-from pint.testsuite import TestCase, unittest
+from pint.testsuite import TestCase
class TestPiTheorem(TestCase):
diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py
index 3824102..088ea4f 100644
--- a/pint/testsuite/test_quantity.py
+++ b/pint/testsuite/test_quantity.py
@@ -6,10 +6,10 @@ import copy
import math
import operator as op
+from pint import DimensionalityError, UnitRegistry
from pint.unit import UnitsContainer
-from pint import DimensionalityError, UndefinedUnitError, UnitRegistry
-
-from pint.testsuite import TestCase, string_types, PYTHON3
+from pint.compat import string_types, PYTHON3
+from pint.testsuite import TestCase
class TestQuantity(TestCase):
diff --git a/pint/testsuite/test_umath.py b/pint/testsuite/test_umath.py
index 5d1f0e8..a44993a 100644
--- a/pint/testsuite/test_umath.py
+++ b/pint/testsuite/test_umath.py
@@ -2,7 +2,8 @@
from __future__ import division, unicode_literals, print_function, absolute_import
-from pint.testsuite import TestCase, HAS_NUMPY, np, unittest
+from pint.compat import HAS_NUMPY, np, unittest
+from pint.testsuite import TestCase
# Following http://docs.scipy.org/doc/numpy/reference/ufuncs.html
diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py
index 96b7963..f101a59 100644
--- a/pint/testsuite/test_unit.py
+++ b/pint/testsuite/test_unit.py
@@ -8,9 +8,11 @@ import operator as op
from pint.unit import (ScaleConverter, OffsetConverter, UnitsContainer,
Definition, PrefixDefinition, UnitDefinition,
- DimensionDefinition)
+ DimensionDefinition, _freeze)
from pint import DimensionalityError, UndefinedUnitError
-from pint.testsuite import TestCase, u, unittest
+from pint.compat import u, unittest
+from pint.testsuite import TestCase
+
class TestConverter(unittest.TestCase):
@@ -24,6 +26,7 @@ class TestConverter(unittest.TestCase):
self.assertEqual(c.from_reference(c.to_reference(100)), 100)
self.assertEqual(c.to_reference(c.from_reference(100)), 100)
+
class TestDefinition(unittest.TestCase):
def test_prefix_definition(self):
@@ -336,3 +339,46 @@ class TestRegistry(TestCase):
h2 = ureg.wraps(('meter', 'cm'), [None, None])(hfunc)
self.assertEqual(h2(3, 1), (3 * ureg.meter, 1 * ureg.cm))
+
+
+class TestEquivalents(TestCase):
+
+ FORCE_NDARRAY= False
+
+ def _test(self, input_units):
+ gd = self.ureg.get_dimensionality
+ dim = gd(input_units)
+ equiv = self.ureg.get_compatible_units(input_units)
+ for eq in equiv:
+ self.assertEqual(gd(eq), dim)
+ self.assertEqual(equiv, self.ureg.get_compatible_units(dim))
+
+ def _test2(self, units1, units2):
+ equiv1 = self.ureg.get_compatible_units(units1)
+ equiv2 = self.ureg.get_compatible_units(units2)
+ self.assertEqual(equiv1, equiv2)
+
+ def test_many(self):
+ self._test(self.ureg.meter.units)
+ self._test(self.ureg.seconds.units)
+ self._test(self.ureg.newton.units)
+ self._test(self.ureg.kelvin.units)
+
+ def test_context_sp(self):
+
+
+ gd = self.ureg.get_dimensionality
+
+ # length, frequency, energy
+ valid = [gd(self.ureg.meter.units), gd(self.ureg.hertz.units), gd(self.ureg.joule.units)]
+
+ with self.ureg.context('sp'):
+ equiv = self.ureg.get_compatible_units(self.ureg.meter.units)
+ result = set()
+ for eq in equiv:
+ dim = gd(eq)
+ result.add(_freeze(dim))
+ self.assertIn(dim, valid)
+
+ self.assertEqual(len(result), len(valid))
+
diff --git a/pint/testsuite/test_util.py b/pint/testsuite/test_util.py
index 9264eb6..5532cd0 100644
--- a/pint/testsuite/test_util.py
+++ b/pint/testsuite/test_util.py
@@ -5,7 +5,7 @@ from __future__ import division, unicode_literals, print_function, absolute_impo
import collections
from pint.util import string_preprocessor, find_shortest_path
-from pint.testsuite import unittest
+from pint.compat import unittest
class TestStringProcessor(unittest.TestCase):
@@ -15,31 +15,51 @@ class TestStringProcessor(unittest.TestCase):
a = pattern.format(aft)
self.assertEqual(string_preprocessor(b), a)
- def test_rules(self):
+ def test_square_cube(self):
self._test('bcd^3', 'bcd**3')
+ self._test('bcd^ 3', 'bcd** 3')
+ self._test('bcd ^3', 'bcd **3')
+ self._test('bcd squared', 'bcd**2')
self._test('bcd squared', 'bcd**2')
self._test('bcd cubed', 'bcd**3')
self._test('sq bcd', 'bcd**2')
self._test('square bcd', 'bcd**2')
self._test('cubic bcd', 'bcd**3')
self._test('bcd efg', 'bcd*efg')
- self._test('bcd efg', 'bcd*efg')
+
+ def test_per(self):
self._test('miles per hour', 'miles/hour')
+
+ def test_numbers(self):
self._test('1,234,567', '1234567')
- self._test('1hour', '1*hour')
- self._test('1.1hour', '1.1*hour')
self._test('1e-24', '1e-24')
self._test('1e+24', '1e+24')
self._test('1e24', '1e24')
self._test('1E-24', '1E-24')
self._test('1E+24', '1E+24')
self._test('1E24', '1E24')
+
+ def test_space_multiplication(self):
+ self._test('bcd efg', 'bcd*efg')
+ self._test('bcd efg', 'bcd*efg')
+ self._test('1 hour', '1*hour')
+ self._test('1. hour', '1.*hour')
+ self._test('1.1 hour', '1.1*hour')
+ self._test('1E24 hour', '1E24*hour')
+ self._test('1E-24 hour', '1E-24*hour')
+ self._test('1E+24 hour', '1E+24*hour')
+ self._test('1.2E24 hour', '1.2E24*hour')
+ self._test('1.2E-24 hour', '1.2E-24*hour')
+ self._test('1.2E+24 hour', '1.2E+24*hour')
+
+ def names(self):
self._test('g_0', 'g_0')
- self._test('1g_0', '1*g_0')
self._test('g0', 'g0')
- self._test('1g0', '1*g0')
self._test('g', 'g')
- self._test('1g', '1*g')
+ self._test('water_60F', 'water_60F')
+
+
+class TestGraph(unittest.TestCase):
def test_shortest_path(self):
g = collections.defaultdict(list)
diff --git a/pint/unit.py b/pint/unit.py
index cbbaeea..c4fa928 100644
--- a/pint/unit.py
+++ b/pint/unit.py
@@ -23,11 +23,11 @@ from io import open
from numbers import Number
from tokenize import untokenize, NUMBER, STRING, NAME, OP
-from .context import Context, ContextChain
+from .context import Context, ContextChain, _freeze
+from .util import (logger, pi_theorem, solve_dependencies, ParserHelper,
+ string_preprocessor, find_connected_nodes, find_shortest_path)
+from .compat import tokenizer, string_types, NUMERIC_TYPES, TransformDict
from .formatter import format_unit
-from .util import (logger, NUMERIC_TYPES, pi_theorem, solve_dependencies,
- ParserHelper, string_types, ptok, string_preprocessor)
-from .util import find_shortest_path
class UndefinedUnitError(ValueError):
@@ -391,6 +391,7 @@ class UnitRegistry(object):
def __init__(self, filename='', force_ndarray=False, default_to_delta=True):
self.Quantity = build_quantity_class(self, force_ndarray)
+ self.Measurement = build_measurement_class(self, force_ndarray)
#: Map dimension name (string) to its definition (DimensionDefinition).
self._dimensions = {}
@@ -410,6 +411,14 @@ class UnitRegistry(object):
#: Stores active contexts.
self._active_ctx = ContextChain()
+ #: Maps dimensionality (_freeze(UnitsContainer)) to Units (str)
+ self._dimensional_equivalents = TransformDict(_freeze)
+
+ #: Maps dimensionality (_freeze(UnitsContainer)) to Dimensionality (_freeze(UnitsContainer))
+ self._base_units_cache = TransformDict(_freeze)
+ #: Maps dimensionality (_freeze(UnitsContainer)) to Units (_freeze(UnitsContainer))
+ self._dimensionality_cache = TransformDict(_freeze)
+
#: When performing a multiplication of units, interpret
#: non-multiplicative units as their *delta* counterparts.
self.default_to_delta = default_to_delta
@@ -422,6 +431,8 @@ class UnitRegistry(object):
self.define(UnitDefinition('pi', 'π', (), ScaleConverter(math.pi)))
+ self._build_cache()
+
def __getattr__(self, item):
return self.Quantity(1, item)
@@ -666,18 +677,39 @@ class UnitRegistry(object):
except Exception as ex:
logger.error("In line {0}, cannot add '{1}' {2}".format(no, line, ex))
- def validate(self):
- """Walk the registry and calculate for each unit definition
- the corresponding base units and dimensionality.
+ def _build_cache(self):
+ """Build a cache of dimensionality and base units.
"""
- deps = dict((name, set(definition.reference.keys()))
+ deps = dict((name, set(definition.reference.keys() if definition.reference else {}))
for name, definition in self._units.items())
for unit_names in solve_dependencies(deps):
for unit_name in unit_names:
- bu = self.get_base_units(unit_name)
- di = self.get_dimensionality(bu)
+ prefixed = False
+ for p in self._prefixes.keys():
+ if p and unit_name.startswith(p):
+ prefixed = True
+ break
+ if '[' in unit_name:
+ continue
+ try:
+ uc = ParserHelper.from_word(unit_name)
+
+ bu = self.get_base_units(uc)
+ di = self.get_dimensionality(uc)
+
+ self._base_units_cache[uc] = bu
+ self._dimensionality_cache[uc] = di
+
+ if not prefixed:
+ if di not in self._dimensional_equivalents:
+ self._dimensional_equivalents[di] = set()
+
+ self._dimensional_equivalents[di].add(self._units[unit_name].name)
+
+ except Exception as e:
+ logger.warning('Could not resolve {}: {!r}'.format(unit_name, e))
def get_name(self, name_or_alias):
"""Return the canonical name of a unit.
@@ -739,6 +771,9 @@ class UnitRegistry(object):
if isinstance(input_units, string_types):
input_units = ParserHelper.from_string(input_units)
+ if input_units in self._dimensionality_cache:
+ return copy.copy(self._dimensionality_cache[input_units])
+
for key, value in input_units.items():
if _is_dim(key):
reg = self._dimensions[key]
@@ -766,9 +801,8 @@ class UnitRegistry(object):
:param input_units: units
:type input_units: UnitsContainer or str
- :param check_nonmult: if True None will be returned as the multiplicative factor
- is a non-multiplicative units is found in the final
- Units.
+ :param check_nonmult: if True, None will be returned as the multiplicative factor
+ is a non-multiplicative units is found in the final Units.
:return: multiplicative factor, base units
"""
if not input_units:
@@ -777,6 +811,10 @@ class UnitRegistry(object):
if isinstance(input_units, string_types):
input_units = ParserHelper.from_string(input_units)
+ # The cache is only done for check_nonmult=True
+ if check_nonmult and input_units in self._base_units_cache:
+ return copy.deepcopy(self._base_units_cache[input_units])
+
factor = 1.
units = UnitsContainer()
for key, value in input_units.items():
@@ -790,7 +828,7 @@ class UnitRegistry(object):
factor *= (reg.converter.scale * fac) ** value
units *= uni ** value
- # Check if any of the final units is non multiplicative and return non instead.
+ # Check if any of the final units is non multiplicative and return None instead.
if check_nonmult:
for unit in units.keys():
if not isinstance(self._units[unit].converter, ScaleConverter):
@@ -798,6 +836,26 @@ class UnitRegistry(object):
return factor, units
+ def get_compatible_units(self, input_units):
+ if not input_units:
+ return 1., UnitsContainer()
+
+ if isinstance(input_units, string_types):
+ input_units = ParserHelper.from_string(input_units)
+
+ src_dim = self.get_dimensionality(input_units)
+
+ ret = self._dimensional_equivalents[src_dim]
+
+ if self._active_ctx:
+ nodes = find_connected_nodes(self._active_ctx.graph, _freeze(src_dim))
+ ret = set()
+ if nodes:
+ for node in nodes:
+ ret |= self._dimensional_equivalents[node]
+
+ return ret
+
def convert(self, value, src, dst):
"""Convert value from some source to destination units.
@@ -967,7 +1025,7 @@ class UnitRegistry(object):
return self.Quantity(1)
input_string = string_preprocessor(input_string)
- gen = ptok(input_string)
+ gen = tokenizer(input_string)
result = []
unknown = set()
for toknum, tokval, _, _, _ in gen:
@@ -1088,3 +1146,22 @@ def build_quantity_class(registry, force_ndarray=False):
Quantity.force_ndarray = force_ndarray
return Quantity
+
+
+def build_measurement_class(registry, force_ndarray=False):
+ from .measurement import _Measurement, ufloat
+
+ if ufloat is None:
+ class Measurement(object):
+
+ def __init__(self, *args):
+ raise RuntimeError("Pint requires the 'uncertainties' package to create a Measurement object.")
+
+ else:
+ class Measurement(_Measurement, registry.Quantity):
+ pass
+
+ Measurement._REGISTRY = registry
+ Measurement.force_ndarray = force_ndarray
+
+ return Measurement
diff --git a/pint/util.py b/pint/util.py
index 1fce174..f8ffd15 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -12,8 +12,6 @@
from __future__ import division, unicode_literals, print_function, absolute_import
import re
-import sys
-import tokenize
import operator
from numbers import Number
from fractions import Fraction
@@ -22,20 +20,11 @@ import logging
from token import STRING, NAME, OP
from tokenize import untokenize
-from .compat import NullHandler
+from .compat import string_types, tokenizer, lru_cache, NullHandler
logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())
-if sys.version < '3':
- from StringIO import StringIO
- string_types = basestring
- ptok = lambda input_string: tokenize.generate_tokens(StringIO(input_string).readline)
-else:
- from io import BytesIO
- string_types = str
- ptok = lambda input_string: tokenize.tokenize(BytesIO(input_string.encode('utf-8')).readline)
-
def matrix_to_string(matrix, row_headers=None, col_headers=None, fmtfun=lambda x: str(int(x))):
"""Takes a 2D matrix (as nested list) and returns a string.
@@ -123,21 +112,6 @@ def column_echelon_form(matrix, ntype=Fraction, transpose_result=False):
return _transpose(M), _transpose(I), swapped
-try:
- import numpy as np
- from numpy import ndarray
-
- HAS_NUMPY = True
- NUMERIC_TYPES = (Number, ndarray)
-
-except ImportError:
-
- class ndarray(object):
- pass
-
- HAS_NUMPY = False
- NUMERIC_TYPES = (Number, )
-
def pi_theorem(quantities, registry=None):
"""Builds dimensionless quantities using the Buckingham π theorem
@@ -217,8 +191,8 @@ def solve_dependencies(dependencies):
return r
-def find_shortest_path(graph, start, end, path=[]):
- path = path + [start]
+def find_shortest_path(graph, start, end, path=None):
+ path = (path or []) + [start]
if start == end:
return path
if not start in graph:
@@ -233,6 +207,20 @@ def find_shortest_path(graph, start, end, path=[]):
return shortest
+def find_connected_nodes(graph, start, visited=None):
+ if not start in graph:
+ return None
+
+ visited = (visited or set())
+ visited.add(start)
+
+ for node in graph[start]:
+ if node not in visited:
+ find_connected_nodes(graph, node, visited)
+
+ return visited
+
+
class ParserHelper(dict):
"""The ParserHelper stores in place the product of variables and
their respective exponent and implements the corresponding operations.
@@ -256,6 +244,7 @@ class ParserHelper(dict):
return ret
@classmethod
+ @lru_cache()
def from_string(cls, input_string):
"""Parse linear expression mathematical units and return a quantity object.
"""
@@ -271,7 +260,7 @@ class ParserHelper(dict):
else:
reps = False
- gen = ptok(input_string)
+ gen = tokenizer(input_string)
result = []
for toknum, tokval, _, _, _ in gen:
if toknum == NAME:
@@ -371,13 +360,13 @@ class ParserHelper(dict):
#: List of regex substitution pairs.
-_subs_re = [(r"({0}) squared", r"\1**2"), # Handle square and cube
+_subs_re = [(r"([\w\.\-\+\*\\\^])\s+", r"\1 "), # merge multiple spaces
+ (r"({0}) squared", r"\1**2"), # Handle square and cube
(r"({0}) cubed", r"\1**3"),
(r"cubic ({0})", r"\1**3"),
(r"square ({0})", r"\1**2"),
(r"sq ({0})", r"\1**2"),
- (r"(\w)\s+(?=\w)", r"\1*"), # Handle space for multiplication
- (r"([0-9])(?={0})(?!(?:[e|E][-+]?[0-9]+))", r"\1*")
+ (r"([\w\.\-])\s+(?=\w)", r"\1*"), # Handle space for multiplication
]
#: Compiles the regex and replace {0} by a regex that matches an identifier.