diff options
author | Anthon van der Neut <anthon@mnt.org> | 2017-08-06 00:45:39 +0200 |
---|---|---|
committer | Anthon van der Neut <anthon@mnt.org> | 2017-08-06 00:45:39 +0200 |
commit | c93c3cd9efdebd72873655494fc695b9c65858eb (patch) | |
tree | a9be64d9e545dba6f2381a9fe17df140c2032470 | |
parent | 51de865df2d66215b5e904813f16fa84a90b215a (diff) | |
download | ruamel.yaml-c93c3cd9efdebd72873655494fc695b9c65858eb.tar.gz |
scalarfloat support
-rw-r--r-- | _test/roundtrip.py | 1 | ||||
-rw-r--r-- | _test/test_float.py | 105 | ||||
-rw-r--r-- | constructor.py | 53 | ||||
-rw-r--r-- | representer.py | 85 | ||||
-rw-r--r-- | scalarfloat.py | 77 |
5 files changed, 313 insertions, 8 deletions
diff --git a/_test/roundtrip.py b/_test/roundtrip.py index 7ce51ae..00413e6 100644 --- a/_test/roundtrip.py +++ b/_test/roundtrip.py @@ -91,6 +91,7 @@ def round_trip(inp, outp=None, extra=None, intermediate=None, indent=None, version=version) print('roundtrip second round data:\n', res, sep='') assert res == doutp + return data class YAML(ruamel.yaml.YAML): diff --git a/_test/test_float.py b/_test/test_float.py new file mode 100644 index 0000000..0adde1d --- /dev/null +++ b/_test/test_float.py @@ -0,0 +1,105 @@ +# coding: utf-8 + +from __future__ import print_function, absolute_import, division, unicode_literals + +import pytest # NOQA + +from roundtrip import round_trip, dedent, round_trip_load, round_trip_dump +from ruamel.yaml.error import MantissaNoDotYAML1_1Warning + +# http://yaml.org/type/int.html is where underscores in integers are defined + + +class TestFloat: + def test_round_trip_non_exp(self): + data = round_trip("""\ + - 1.0 + - 1.00 + - 23.100 + """) + print(data) + assert 0.999 < data[0] < 1.001 + assert 0.999 < data[1] < 1.001 + assert 23.099 < data[2] < 23.101 + + @pytest.mark.xfail(strict=True) + def test_round_trip_non_exp_trailing_dot(self): + data = round_trip("""\ + - 42. + """) + print(data) + assert 41.999 < data[0] < 42.001 + + def test_round_trip_exp_00(self): + data = round_trip("""\ + - 42e56 + - 42E56 + - 42.0E56 + - +42.0e56 + - 42.0E+056 + - +42.00e+056 + """) + print(data) + for d in data: + assert 41.99e56 < d < 42.01e56 + + @pytest.mark.xfail(strict=True) + def test_round_trip_exp_00f(self): + data = round_trip("""\ + - 42.E56 + """) + print(data) + for d in data: + assert 41.99e56 < d < 42.01e56 + + def test_round_trip_exp_01(self): + data = round_trip("""\ + - -42e56 + - -42E56 + - -42.0e56 + - -42.0E+056 + """) + print(data) + for d in data: + assert -41.99e56 > d > -42.01e56 + + def test_round_trip_exp_02(self): + data = round_trip("""\ + - 42e-56 + - 42E-56 + - 42.0E-56 + - +42.0e-56 + - 42.0E-056 + - +42.0e-056 + """) + print(data) + for d in data: + assert 41.99e-56 < d < 42.01e-56 + + def test_round_trip_exp_03(self): + data = round_trip("""\ + - -42e-56 + - -42E-56 + - -42.0e-56 + - -42.0E-056 + """) + print(data) + for d in data: + assert -41.99e-56 > d > -42.01e-56 + + def test_round_trip_exp_04(self): + data = round_trip("""\ + - 1.2e+34 + - 1.23e+034 + - 1.230e+34 + - 1.023e+34 + - -1.023e+34 + """) + + def test_yaml_1_1_no_dot(self): + with pytest.warns(MantissaNoDotYAML1_1Warning): + data = round_trip_load("""\ + %YAML 1.1 + --- + - 1e6 + """) diff --git a/constructor.py b/constructor.py index b67a748..7e572ad 100644 --- a/constructor.py +++ b/constructor.py @@ -24,6 +24,7 @@ from ruamel.yaml.scalarstring import * # NOQA from ruamel.yaml.scalarstring import (PreservedScalarString, SingleQuotedScalarString, DoubleQuotedScalarString, ScalarString) from ruamel.yaml.scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt +from ruamel.yaml.scalarfloat import ScalarFloat from ruamel.yaml.timestamp import TimeStamp if False: # MYPY @@ -1040,6 +1041,58 @@ class RoundTripConstructor(SafeConstructor): else: return sign * int(value_s) + def construct_yaml_float(self, node): + # type: (Any) -> float + underscore = None + m_sign = False + value_so = to_str(self.construct_scalar(node)) + value_s = value_so.replace('_', '').lower() + sign = +1 + if value_s[0] == '-': + sign = -1 + if value_s[0] in '+-': + m_sign = True + value_s = value_s[1:] + if value_s == '.inf': + return sign * self.inf_value + if value_s == '.nan': + return self.nan_value + if self.resolver.processing_version != (1, 2) and ':' in value_s: + digits = [float(part) for part in value_s.split(':')] + digits.reverse() + base = 1 + value = 0.0 + for digit in digits: + value += digit * base + base *= 60 + return sign * value + if 'e' in value_s: + try: + mantissa, exponent = value_so.split('e') + exp = 'e' + except ValueError: + mantissa, exponent = value_so.split('E') + exp = 'E' + if self.resolver.processing_version != (1, 2): + # value_s is lower case independent of input + if '.' not in mantissa: + warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so)) + width = len(mantissa) + prec = mantissa.find('.') + if prec > width - 2: + prec = 0 + if m_sign: + width -= 1 + e_width = len(exponent) + e_sign = exponent[0] in '+-' + # print('sf', width, prec, m_sign, exp, e_width, e_sign) + return ScalarFloat(sign * float(value_s), width=width, prec=prec, m_sign=m_sign, + exp=exp, e_width=e_width, e_sign=e_sign) + width = len(value_so) + prec = value_so.index('.') + return ScalarFloat(sign * float(value_s), width=width, prec=prec) + # return sign * float(value_s) + def construct_yaml_str(self, node): # type: (Any) -> Any value = self.construct_scalar(node) diff --git a/representer.py b/representer.py index f441d80..d2cc557 100644 --- a/representer.py +++ b/representer.py @@ -9,6 +9,7 @@ from ruamel.yaml.compat import text_type, binary_type, to_unicode, PY2, PY3, ord from ruamel.yaml.scalarstring import (PreservedScalarString, SingleQuotedScalarString, DoubleQuotedScalarString) from ruamel.yaml.scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt +from ruamel.yaml.scalarfloat import ScalarFloat from ruamel.yaml.timestamp import TimeStamp import datetime @@ -308,14 +309,14 @@ class SafeRepresenter(BaseRepresenter): value = u'-.inf' else: value = to_unicode(repr(data)).lower() - # Note that in some cases `repr(data)` represents a float number - # without the decimal parts. For instance: - # >>> repr(1e17) - # '1e17' - # Unfortunately, this is not a valid float representation according - # to the definition of the `!!float` tag. We fix this by adding - # '.0' before the 'e' symbol. - if u'.' not in value and u'e' in value: + if self.dumper.version == (1, 1) and u'.' not in value and u'e' in value: + # Note that in some cases `repr(data)` represents a float number + # without the decimal parts. For instance: + # >>> repr(1e17) + # '1e17' + # Unfortunately, this is not a valid float representation according + # to the definition of the `!!float` tag in YAML 1.1. We fix this by adding + # '.0' before the 'e' symbol. value = value.replace(u'e', u'.0e', 1) return self.represent_scalar(u'tag:yaml.org,2002:float', value) @@ -751,6 +752,70 @@ class RoundTripRepresenter(SafeRepresenter): s = format(data, 'X') return self.insert_underscore('0x', s, data._underscore) + def represent_scalar_float(self, data): + # type: (Any) -> Any + value = None + if data != data or (data == 0.0 and data == 1.0): + value = u'.nan' + elif data == self.inf_value: + value = u'.inf' + elif data == -self.inf_value: + value = u'-.inf' + if value: + return self.represent_scalar(u'tag:yaml.org,2002:float', value) + if data._exp is None: + prec = data._prec + if prec < 1: + prec = 1 + # print('dw2', data._width, prec) + value = '{:{}.{}f}'.format(data, data._width, data._width-prec-1) + while len(value) < data._width: + value += '0' + else: + # print('pr', data._width, prec) + # if data._prec > 0: + # prec = data._width - data.prec + # prec = data._prec + # if prec < 1: + # prec = 1 + m, es = '{:{}e}'.format(data, data._width).split('e') + w = data._width if data._prec > 0 else (data._width + 1) + if data < 0: + w += 1 + m = m[:w] + e = int(es) + m1, m2 = m.split('.') # always second? + while len(m1) + len(m2) < data._width - (1 if data._prec >= 0 else 0): + m2 += '0' + if data._m_sign and data > 0: + m1 = '+' + m1 + esgn = '+' if data._e_sign else '' + if data._prec < 0: # mantissa without dot + # print('ew2', m2, len(m2), e) + if m2 != '0': + e -= len(m2) + else: + m2 = '' + value = m1 + m2 + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width) + elif data._prec == 0: # mantissa with trailind dot + e -= len(m2) + value = m1 + m2 + '.' + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width) + else: + while len(m1) < data._prec: + m1 += m2[0] + m2 = m2[1:] + e -= 1 + value = m1 + '.' + m2 + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width) + + if value is None: + value = to_unicode(repr(data)).lower() + return self.represent_scalar(u'tag:yaml.org,2002:float', value) + #if data._width is not None: + # s = '{:0{}d}'.format(data, data._width) + #else: + # s = format(data, 'f') + #return self.insert_underscore('', s, data._underscore) + def represent_sequence(self, tag, sequence, flow_style=None): # type: (Any, Any, Any) -> Any value = [] # type: List[Any] @@ -1048,6 +1113,10 @@ RoundTripRepresenter.add_representer( HexCapsInt, RoundTripRepresenter.represent_hex_caps_int) +RoundTripRepresenter.add_representer( + ScalarFloat, + RoundTripRepresenter.represent_scalar_float) + RoundTripRepresenter.add_representer(CommentedSeq, RoundTripRepresenter.represent_list) diff --git a/scalarfloat.py b/scalarfloat.py new file mode 100644 index 0000000..01d7080 --- /dev/null +++ b/scalarfloat.py @@ -0,0 +1,77 @@ +# coding: utf-8 + +from __future__ import print_function, absolute_import, division, unicode_literals + +if False: # MYPY + from typing import Text, Any, Dict, List # NOQA + +__all__ = ["ScalarFloat", "ExponentialFloat", "ExponentialCapsFloat"] + +from .compat import no_limit_int # NOQA + + +class ScalarFloat(float): + def __new__(cls, *args, **kw): + # type: (Any, Any, Any) -> Any + width = kw.pop('width', None) # type: ignore + prec = kw.pop('prec', None) # type: ignore + m_sign = kw.pop('m_sign', None) # type: ignore + exp = kw.pop('exp', None) # type: ignore + e_width = kw.pop('e_width', None) # type: ignore + e_sign = kw.pop('e_sign', None) # type: ignore + underscore = kw.pop('underscore', None) # type: ignore + v = float.__new__(cls, *args, **kw) # type: ignore + v._width = width + v._prec = prec + v._m_sign = m_sign + v._exp = exp + v._e_width = e_width + v._e_sign = e_sign + v._underscore = underscore + return v + + def __iadd__(self, a): # type: ignore + # type: (Any) -> Any + x = type(self)(self + a) + x._width = self._width # type: ignore + x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA + return x + + def __ifloordiv__(self, a): # type: ignore + # type: (Any) -> Any + x = type(self)(self // a) + x._width = self._width # type: ignore + x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA + return x + + def __imul__(self, a): # type: ignore + # type: (Any) -> Any + x = type(self)(self * a) + x._width = self._width # type: ignore + x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA + return x + + def __ipow__(self, a): # type: ignore + # type: (Any) -> Any + x = type(self)(self ** a) + x._width = self._width # type: ignore + x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA + return x + + def __isub__(self, a): # type: ignore + # type: (Any) -> Any + x = type(self)(self - a) + x._width = self._width # type: ignore + x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA + return x + + +class ExponentialFloat(ScalarFloat): + def __new__(cls, value, width=None, underscore=None): + # type: (Any, Any, Any) -> Any + return ScalarFloat.__new__(cls, value, width=width, underscore=underscore) + +class ExponentialCapsFloat(ScalarFloat): + def __new__(cls, value, width=None, underscore=None): + # type: (Any, Any, Any) -> Any + return ScalarFloat.__new__(cls, value, width=width, underscore=underscore) |