summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hernan.grecco@gmail.com>2014-07-03 17:21:25 -0300
committerHernan Grecco <hernan.grecco@gmail.com>2014-07-03 17:21:25 -0300
commit33854821fd08e19bf46c3b5efb5a0c2b009665e2 (patch)
treeb92199ff8c9d2f46d57b7378160aca42ff2dbd20
parent17afd4fd4725ef24fa164d0ed34c23ee90016d97 (diff)
parentc8ae9eeccd9334c2fb9976067a6f1340e3f97752 (diff)
downloadpint-33854821fd08e19bf46c3b5efb5a0c2b009665e2.tar.gz
Merge branch 'jbmohler-recurse_cls_method' into develop
-rw-r--r--pint/unit.py63
1 files changed, 36 insertions, 27 deletions
diff --git a/pint/unit.py b/pint/unit.py
index 27f6b5e..7ef47f7 100644
--- a/pint/unit.py
+++ b/pint/unit.py
@@ -885,9 +885,8 @@ class UnitRegistry(object):
:param input_units:
:return: dimensionality
"""
- dims = UnitsContainer()
if not input_units:
- return dims
+ return UnitsContainer()
if isinstance(input_units, string_types):
input_units = ParserHelper.from_string(input_units)
@@ -895,25 +894,30 @@ class UnitRegistry(object):
if input_units in self._dimensionality_cache:
return copy.copy(self._dimensionality_cache[input_units])
- for key, value in input_units.items():
- if _is_dim(key):
- reg = self._dimensions[key]
- if reg.is_base:
- dims.add(key, value)
- else:
- dims *= self.get_dimensionality(reg.reference) ** value
- else:
- reg = self._units[self.get_name(key)]
- if reg.is_base:
- dims *= reg.reference ** value
- else:
- dims *= self.get_dimensionality(reg.reference) ** value
+ accumulator = defaultdict(lambda: 0.0)
+ self._get_dimensionality_recurse(input_units, 1.0, accumulator)
+
+ dims = UnitsContainer(dict((k, v) for k, v in accumulator.items() if v != 0.))
if '[]' in dims:
del dims['[]']
return dims
+ def _get_dimensionality_recurse(self, ref, exp, accumulator):
+ for key, value in ref.items():
+ exp2 = exp*value
+ if _is_dim(key):
+ reg = self._dimensions[key]
+ if reg.is_base:
+ accumulator[key] += exp2
+ elif reg.reference != None:
+ self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
+ else:
+ reg = self._units[self.get_name(key)]
+ if reg.reference != None:
+ self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
+
def get_base_units(self, input_units, check_nonmult=True):
"""Convert unit or dict of units to the base units.
@@ -936,18 +940,11 @@ class UnitRegistry(object):
if check_nonmult and input_units in self._base_units_cache:
return copy.deepcopy(self._base_units_cache[input_units])
- factor = 1.
- units = UnitsContainer()
- for key, value in input_units.items():
- key = self.get_name(key)
- reg = self._units[key]
- if reg.is_base:
- units.add(key, value)
- else:
- fac, uni = self.get_base_units(reg.reference, check_nonmult=False)
- if factor is not None:
- factor *= (reg.converter.scale * fac) ** value
- units *= uni ** value
+ accumulators = [1., defaultdict(lambda: 0.0)]
+ self._get_base_units(input_units, 1.0, accumulators)
+
+ factor = accumulators[0]
+ units = UnitsContainer(dict((k, v) for k, v in accumulators[1].items() if v != 0.))
# Check if any of the final units is non multiplicative and return None instead.
if check_nonmult:
@@ -957,6 +954,18 @@ class UnitRegistry(object):
return factor, units
+ def _get_base_units(self, ref, exp, accumulators):
+ for key, value in ref.items():
+ key = self.get_name(key)
+ reg = self._units[key]
+ exp2 = exp*value
+ if reg.is_base:
+ accumulators[1][key] += exp2
+ else:
+ accumulators[0] *= reg.converter.scale ** exp2
+ if reg.reference != None:
+ self._get_base_units(reg.reference, exp2, accumulators)
+
def get_compatible_units(self, input_units):
if not input_units:
return 1., UnitsContainer()