summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthon van der Neut <anthon@mnt.org>2017-08-06 00:45:39 +0200
committerAnthon van der Neut <anthon@mnt.org>2017-08-06 00:45:39 +0200
commitc93c3cd9efdebd72873655494fc695b9c65858eb (patch)
treea9be64d9e545dba6f2381a9fe17df140c2032470
parent51de865df2d66215b5e904813f16fa84a90b215a (diff)
downloadruamel.yaml-c93c3cd9efdebd72873655494fc695b9c65858eb.tar.gz
scalarfloat support
-rw-r--r--_test/roundtrip.py1
-rw-r--r--_test/test_float.py105
-rw-r--r--constructor.py53
-rw-r--r--representer.py85
-rw-r--r--scalarfloat.py77
5 files changed, 313 insertions, 8 deletions
diff --git a/_test/roundtrip.py b/_test/roundtrip.py
index 7ce51ae..00413e6 100644
--- a/_test/roundtrip.py
+++ b/_test/roundtrip.py
@@ -91,6 +91,7 @@ def round_trip(inp, outp=None, extra=None, intermediate=None, indent=None,
version=version)
print('roundtrip second round data:\n', res, sep='')
assert res == doutp
+ return data
class YAML(ruamel.yaml.YAML):
diff --git a/_test/test_float.py b/_test/test_float.py
new file mode 100644
index 0000000..0adde1d
--- /dev/null
+++ b/_test/test_float.py
@@ -0,0 +1,105 @@
+# coding: utf-8
+
+from __future__ import print_function, absolute_import, division, unicode_literals
+
+import pytest # NOQA
+
+from roundtrip import round_trip, dedent, round_trip_load, round_trip_dump
+from ruamel.yaml.error import MantissaNoDotYAML1_1Warning
+
+# http://yaml.org/type/int.html is where underscores in integers are defined
+
+
+class TestFloat:
+ def test_round_trip_non_exp(self):
+ data = round_trip("""\
+ - 1.0
+ - 1.00
+ - 23.100
+ """)
+ print(data)
+ assert 0.999 < data[0] < 1.001
+ assert 0.999 < data[1] < 1.001
+ assert 23.099 < data[2] < 23.101
+
+ @pytest.mark.xfail(strict=True)
+ def test_round_trip_non_exp_trailing_dot(self):
+ data = round_trip("""\
+ - 42.
+ """)
+ print(data)
+ assert 41.999 < data[0] < 42.001
+
+ def test_round_trip_exp_00(self):
+ data = round_trip("""\
+ - 42e56
+ - 42E56
+ - 42.0E56
+ - +42.0e56
+ - 42.0E+056
+ - +42.00e+056
+ """)
+ print(data)
+ for d in data:
+ assert 41.99e56 < d < 42.01e56
+
+ @pytest.mark.xfail(strict=True)
+ def test_round_trip_exp_00f(self):
+ data = round_trip("""\
+ - 42.E56
+ """)
+ print(data)
+ for d in data:
+ assert 41.99e56 < d < 42.01e56
+
+ def test_round_trip_exp_01(self):
+ data = round_trip("""\
+ - -42e56
+ - -42E56
+ - -42.0e56
+ - -42.0E+056
+ """)
+ print(data)
+ for d in data:
+ assert -41.99e56 > d > -42.01e56
+
+ def test_round_trip_exp_02(self):
+ data = round_trip("""\
+ - 42e-56
+ - 42E-56
+ - 42.0E-56
+ - +42.0e-56
+ - 42.0E-056
+ - +42.0e-056
+ """)
+ print(data)
+ for d in data:
+ assert 41.99e-56 < d < 42.01e-56
+
+ def test_round_trip_exp_03(self):
+ data = round_trip("""\
+ - -42e-56
+ - -42E-56
+ - -42.0e-56
+ - -42.0E-056
+ """)
+ print(data)
+ for d in data:
+ assert -41.99e-56 > d > -42.01e-56
+
+ def test_round_trip_exp_04(self):
+ data = round_trip("""\
+ - 1.2e+34
+ - 1.23e+034
+ - 1.230e+34
+ - 1.023e+34
+ - -1.023e+34
+ """)
+
+ def test_yaml_1_1_no_dot(self):
+ with pytest.warns(MantissaNoDotYAML1_1Warning):
+ data = round_trip_load("""\
+ %YAML 1.1
+ ---
+ - 1e6
+ """)
diff --git a/constructor.py b/constructor.py
index b67a748..7e572ad 100644
--- a/constructor.py
+++ b/constructor.py
@@ -24,6 +24,7 @@ from ruamel.yaml.scalarstring import * # NOQA
from ruamel.yaml.scalarstring import (PreservedScalarString, SingleQuotedScalarString,
DoubleQuotedScalarString, ScalarString)
from ruamel.yaml.scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
+from ruamel.yaml.scalarfloat import ScalarFloat
from ruamel.yaml.timestamp import TimeStamp
if False: # MYPY
@@ -1040,6 +1041,58 @@ class RoundTripConstructor(SafeConstructor):
else:
return sign * int(value_s)
+ def construct_yaml_float(self, node):
+ # type: (Any) -> float
+ underscore = None
+ m_sign = False
+ value_so = to_str(self.construct_scalar(node))
+ value_s = value_so.replace('_', '').lower()
+ sign = +1
+ if value_s[0] == '-':
+ sign = -1
+ if value_s[0] in '+-':
+ m_sign = True
+ value_s = value_s[1:]
+ if value_s == '.inf':
+ return sign * self.inf_value
+ if value_s == '.nan':
+ return self.nan_value
+ if self.resolver.processing_version != (1, 2) and ':' in value_s:
+ digits = [float(part) for part in value_s.split(':')]
+ digits.reverse()
+ base = 1
+ value = 0.0
+ for digit in digits:
+ value += digit * base
+ base *= 60
+ return sign * value
+ if 'e' in value_s:
+ try:
+ mantissa, exponent = value_so.split('e')
+ exp = 'e'
+ except ValueError:
+ mantissa, exponent = value_so.split('E')
+ exp = 'E'
+ if self.resolver.processing_version != (1, 2):
+ # value_s is lower case independent of input
+ if '.' not in mantissa:
+ warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so))
+ width = len(mantissa)
+ prec = mantissa.find('.')
+ if prec > width - 2:
+ prec = 0
+ if m_sign:
+ width -= 1
+ e_width = len(exponent)
+ e_sign = exponent[0] in '+-'
+ # print('sf', width, prec, m_sign, exp, e_width, e_sign)
+ return ScalarFloat(sign * float(value_s), width=width, prec=prec, m_sign=m_sign,
+ exp=exp, e_width=e_width, e_sign=e_sign)
+ width = len(value_so)
+ prec = value_so.index('.')
+ return ScalarFloat(sign * float(value_s), width=width, prec=prec)
+ # return sign * float(value_s)
+
def construct_yaml_str(self, node):
# type: (Any) -> Any
value = self.construct_scalar(node)
diff --git a/representer.py b/representer.py
index f441d80..d2cc557 100644
--- a/representer.py
+++ b/representer.py
@@ -9,6 +9,7 @@ from ruamel.yaml.compat import text_type, binary_type, to_unicode, PY2, PY3, ord
from ruamel.yaml.scalarstring import (PreservedScalarString, SingleQuotedScalarString,
DoubleQuotedScalarString)
from ruamel.yaml.scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
+from ruamel.yaml.scalarfloat import ScalarFloat
from ruamel.yaml.timestamp import TimeStamp
import datetime
@@ -308,14 +309,14 @@ class SafeRepresenter(BaseRepresenter):
value = u'-.inf'
else:
value = to_unicode(repr(data)).lower()
- # Note that in some cases `repr(data)` represents a float number
- # without the decimal parts. For instance:
- # >>> repr(1e17)
- # '1e17'
- # Unfortunately, this is not a valid float representation according
- # to the definition of the `!!float` tag. We fix this by adding
- # '.0' before the 'e' symbol.
- if u'.' not in value and u'e' in value:
+ if self.dumper.version == (1, 1) and u'.' not in value and u'e' in value:
+ # Note that in some cases `repr(data)` represents a float number
+ # without the decimal parts. For instance:
+ # >>> repr(1e17)
+ # '1e17'
+ # Unfortunately, this is not a valid float representation according
+ # to the definition of the `!!float` tag in YAML 1.1. We fix this by adding
+ # '.0' before the 'e' symbol.
value = value.replace(u'e', u'.0e', 1)
return self.represent_scalar(u'tag:yaml.org,2002:float', value)
@@ -751,6 +752,70 @@ class RoundTripRepresenter(SafeRepresenter):
s = format(data, 'X')
return self.insert_underscore('0x', s, data._underscore)
+ def represent_scalar_float(self, data):
+ # type: (Any) -> Any
+ value = None
+ if data != data or (data == 0.0 and data == 1.0):
+ value = u'.nan'
+ elif data == self.inf_value:
+ value = u'.inf'
+ elif data == -self.inf_value:
+ value = u'-.inf'
+ if value:
+ return self.represent_scalar(u'tag:yaml.org,2002:float', value)
+ if data._exp is None:
+ prec = data._prec
+ if prec < 1:
+ prec = 1
+ # print('dw2', data._width, prec)
+ value = '{:{}.{}f}'.format(data, data._width, data._width-prec-1)
+ while len(value) < data._width:
+ value += '0'
+ else:
+ # print('pr', data._width, prec)
+ # if data._prec > 0:
+ # prec = data._width - data.prec
+ # prec = data._prec
+ # if prec < 1:
+ # prec = 1
+ m, es = '{:{}e}'.format(data, data._width).split('e')
+ w = data._width if data._prec > 0 else (data._width + 1)
+ if data < 0:
+ w += 1
+ m = m[:w]
+ e = int(es)
+ m1, m2 = m.split('.') # always second?
+ while len(m1) + len(m2) < data._width - (1 if data._prec >= 0 else 0):
+ m2 += '0'
+ if data._m_sign and data > 0:
+ m1 = '+' + m1
+ esgn = '+' if data._e_sign else ''
+ if data._prec < 0: # mantissa without dot
+ # print('ew2', m2, len(m2), e)
+ if m2 != '0':
+ e -= len(m2)
+ else:
+ m2 = ''
+ value = m1 + m2 + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width)
+ elif data._prec == 0: # mantissa with trailind dot
+ e -= len(m2)
+ value = m1 + m2 + '.' + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width)
+ else:
+ while len(m1) < data._prec:
+ m1 += m2[0]
+ m2 = m2[1:]
+ e -= 1
+ value = m1 + '.' + m2 + data._exp + '{:{}0{}d}'.format(e, esgn, data._e_width)
+
+ if value is None:
+ value = to_unicode(repr(data)).lower()
+ return self.represent_scalar(u'tag:yaml.org,2002:float', value)
+ #if data._width is not None:
+ # s = '{:0{}d}'.format(data, data._width)
+ #else:
+ # s = format(data, 'f')
+ #return self.insert_underscore('', s, data._underscore)
+
def represent_sequence(self, tag, sequence, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
@@ -1048,6 +1113,10 @@ RoundTripRepresenter.add_representer(
HexCapsInt,
RoundTripRepresenter.represent_hex_caps_int)
+RoundTripRepresenter.add_representer(
+ ScalarFloat,
+ RoundTripRepresenter.represent_scalar_float)
+
RoundTripRepresenter.add_representer(CommentedSeq,
RoundTripRepresenter.represent_list)
diff --git a/scalarfloat.py b/scalarfloat.py
new file mode 100644
index 0000000..01d7080
--- /dev/null
+++ b/scalarfloat.py
@@ -0,0 +1,77 @@
+# coding: utf-8
+
+from __future__ import print_function, absolute_import, division, unicode_literals
+
+if False: # MYPY
+ from typing import Text, Any, Dict, List # NOQA
+
+__all__ = ["ScalarFloat", "ExponentialFloat", "ExponentialCapsFloat"]
+
+from .compat import no_limit_int # NOQA
+
+
+class ScalarFloat(float):
+ def __new__(cls, *args, **kw):
+ # type: (Any, Any, Any) -> Any
+ width = kw.pop('width', None) # type: ignore
+ prec = kw.pop('prec', None) # type: ignore
+ m_sign = kw.pop('m_sign', None) # type: ignore
+ exp = kw.pop('exp', None) # type: ignore
+ e_width = kw.pop('e_width', None) # type: ignore
+ e_sign = kw.pop('e_sign', None) # type: ignore
+ underscore = kw.pop('underscore', None) # type: ignore
+ v = float.__new__(cls, *args, **kw) # type: ignore
+ v._width = width
+ v._prec = prec
+ v._m_sign = m_sign
+ v._exp = exp
+ v._e_width = e_width
+ v._e_sign = e_sign
+ v._underscore = underscore
+ return v
+
+ def __iadd__(self, a): # type: ignore
+ # type: (Any) -> Any
+ x = type(self)(self + a)
+ x._width = self._width # type: ignore
+ x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA
+ return x
+
+ def __ifloordiv__(self, a): # type: ignore
+ # type: (Any) -> Any
+ x = type(self)(self // a)
+ x._width = self._width # type: ignore
+ x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA
+ return x
+
+ def __imul__(self, a): # type: ignore
+ # type: (Any) -> Any
+ x = type(self)(self * a)
+ x._width = self._width # type: ignore
+ x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA
+ return x
+
+ def __ipow__(self, a): # type: ignore
+ # type: (Any) -> Any
+ x = type(self)(self ** a)
+ x._width = self._width # type: ignore
+ x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA
+ return x
+
+ def __isub__(self, a): # type: ignore
+ # type: (Any) -> Any
+ x = type(self)(self - a)
+ x._width = self._width # type: ignore
+ x._underscore = self._underscore[:] if self._underscore is not None else None # type: ignore # NOQA
+ return x
+
+
+class ExponentialFloat(ScalarFloat):
+ def __new__(cls, value, width=None, underscore=None):
+ # type: (Any, Any, Any) -> Any
+ return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)
+
+class ExponentialCapsFloat(ScalarFloat):
+ def __new__(cls, value, width=None, underscore=None):
+ # type: (Any, Any, Any) -> Any
+ return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)