summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hernan.grecco@gmail.com>2014-03-29 12:02:15 -0300
committerHernan Grecco <hernan.grecco@gmail.com>2014-03-29 12:11:58 -0300
commite215db7b28bc67e5cd126f8734cdcb8bddc2a4e0 (patch)
treec0392384474bd04c196aa9bc149377f45e0cd840
parent9135642432ab67280d7ea883288554f9dc50de35 (diff)
downloadpint-e215db7b28bc67e5cd126f8734cdcb8bddc2a4e0.tar.gz
Implemented case_sensitive argument for parse_expression
The argument is added to: - parse_expression - get_name - parse_unit_name It also adds a private dictionary to UnitRegistry (`_units_casei` that maps units names (in lower case) to a set of real unit names (in the correct case) See #105
-rw-r--r--pint/testsuite/test_issues.py14
-rw-r--r--pint/unit.py44
2 files changed, 45 insertions, 13 deletions
diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py
index 027db98..a2e7e7b 100644
--- a/pint/testsuite/test_issues.py
+++ b/pint/testsuite/test_issues.py
@@ -215,6 +215,20 @@ class TestIssues(TestCase):
self.assertEqual(sum([v * ureg.meter, v * ureg.meter]), 2 * v * ureg.meter)
+ def test_issue105(self):
+ ureg = UnitRegistry()
+
+ func = ureg.parse_unit_name
+ val = list(func('meter'))
+ self.assertEqual(list(func('METER')), [])
+ self.assertEqual(val, list(func('METER', False)))
+
+ for func in (ureg.get_name, ureg.parse_expression):
+ val = func('meter')
+ self.assertRaises(ValueError, func, 'METER')
+ self.assertEqual(val, func('METER', False))
+
+
@unittest.skipUnless(HAS_NUMPY, 'Numpy not present')
class TestIssuesNP(TestCase):
diff --git a/pint/unit.py b/pint/unit.py
index 4ef4ab5..0f5f8b4 100644
--- a/pint/unit.py
+++ b/pint/unit.py
@@ -21,6 +21,7 @@ from decimal import Decimal
from contextlib import contextmanager
from io import open, StringIO
from numbers import Number
+from collections import defaultdict
from tokenize import untokenize, NUMBER, STRING, NAME, OP
from .context import Context, ContextChain, _freeze
@@ -433,13 +434,21 @@ class UnitRegistry(object):
self.Quantity = build_quantity_class(self, force_ndarray)
self.Measurement = build_measurement_class(self, force_ndarray)
+ #: Action to take in case a unit is redefined. 'warn', 'raise', 'ignore'
self._on_redefinition = on_redefinition
+
#: Map dimension name (string) to its definition (DimensionDefinition).
self._dimensions = {}
#: Map unit name (string) to its definition (UnitDefinition).
+ #: Might contain prefixed units.
self._units = {}
+ #: Map unit name in lower case (string) to a set of unit names with the right case.
+ #: Does not contain prefixed units.
+ #: e.g: 'hz' - > set('Hz', )
+ self._units_casei = defaultdict(set)
+
#: Map prefix name (string) to its definition (PrefixDefinition).
self._prefixes = {'': PrefixDefinition('', '', (), 1)}
@@ -627,9 +636,9 @@ class UnitRegistry(object):
definition = Definition.from_string(definition)
if isinstance(definition, DimensionDefinition):
- d = self._dimensions
+ d, di = self._dimensions, None
elif isinstance(definition, UnitDefinition):
- d = self._units
+ d, di = self._units, self._units_casei
if definition.is_base:
for dimension in definition.reference.keys():
if dimension in self._dimensions:
@@ -640,11 +649,11 @@ class UnitRegistry(object):
self.define(DimensionDefinition(dimension, '', (), None, is_base=True))
elif isinstance(definition, PrefixDefinition):
- d = self._prefixes
+ d, di = self._prefixes, None
else:
raise TypeError('{0} is not a valid definition.'.format(definition))
- def _adder(key, value, action=self._on_redefinition, selected_dict=d):
+ def _adder(key, value, action=self._on_redefinition, selected_dict=d, casei_dict=di):
if key in selected_dict:
if action == 'raise':
raise RedefinitionError(key, type(value))
@@ -652,6 +661,8 @@ class UnitRegistry(object):
logger.warning("Redefining '%s' (%s)", key, type(value))
selected_dict[key] = value
+ if casei_dict is not None:
+ casei_dict[key.lower()].add(key)
_adder(definition.name, definition)
@@ -776,7 +787,7 @@ class UnitRegistry(object):
except Exception as e:
logger.warning('Could not resolve {0}: {1!r}'.format(unit_name, e))
- def get_name(self, name_or_alias):
+ def get_name(self, name_or_alias, case_sensitive=True):
"""Return the canonical name of a unit.
"""
@@ -788,7 +799,7 @@ class UnitRegistry(object):
except KeyError:
pass
- candidates = self._dedup_candidates(self.parse_unit_name(name_or_alias))
+ candidates = self._dedup_candidates(self.parse_unit_name(name_or_alias, case_sensitive))
if not candidates:
raise UndefinedUnitError(name_or_alias)
elif len(candidates) == 1:
@@ -1026,10 +1037,11 @@ class UnitRegistry(object):
return tuple(unique)
- def parse_unit_name(self, unit_name):
+ def parse_unit_name(self, unit_name, case_sensitive=True):
"""Parse a unit to identify prefix, unit name and suffix
by walking the list of prefix and suffix.
"""
+
for suffix, prefix in itertools.product(self._suffixes.keys(), self._prefixes.keys()):
if unit_name.startswith(prefix) and unit_name.endswith(suffix):
name = unit_name[len(prefix):]
@@ -1037,10 +1049,16 @@ class UnitRegistry(object):
name = name[:-len(suffix)]
if len(name) == 1:
continue
- if name in self._units:
- yield (self._prefixes[prefix].name,
- self._units[name].name,
- self._suffixes[suffix])
+ if case_sensitive:
+ if name in self._units:
+ yield (self._prefixes[prefix].name,
+ self._units[name].name,
+ self._suffixes[suffix])
+ else:
+ for real_name in self._units_casei.get(name.lower(), ()):
+ yield (self._prefixes[prefix].name,
+ self._units[real_name].name,
+ self._suffixes[suffix])
def parse_units(self, input_string, to_delta=None):
"""Parse a units expression and returns a UnitContainer with
@@ -1079,7 +1097,7 @@ class UnitRegistry(object):
return ret
- def parse_expression(self, input_string, **values):
+ def parse_expression(self, input_string, case_sensitive=True, **values):
"""Parse a mathematical expression including units and return a quantity object.
Numerical constants can be specified as keyword arguments and will take precedence
@@ -1100,7 +1118,7 @@ class UnitRegistry(object):
result.append((toknum, tokval))
continue
try:
- tokval = self.get_name(tokval)
+ tokval = self.get_name(tokval, case_sensitive)
except UndefinedUnitError as ex:
unknown.add(ex.unit_names)
if tokval: