diff options
Diffstat (limited to 'Lib/statistics.py')
| -rw-r--r-- | Lib/statistics.py | 79 | 
1 files changed, 66 insertions, 13 deletions
| diff --git a/Lib/statistics.py b/Lib/statistics.py index cf8eaa0a61..9f1efa21b1 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -137,7 +137,7 @@ from decimal import Decimal  from itertools import groupby, repeat  from bisect import bisect_left, bisect_right  from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum -from operator import itemgetter, mul +from operator import mul  from collections import Counter, namedtuple  _SQRT2 = sqrt(2.0) @@ -248,6 +248,28 @@ def _exact_ratio(x):      x is expected to be an int, Fraction, Decimal or float.      """ + +    # XXX We should revisit whether using fractions to accumulate exact +    # ratios is the right way to go. + +    # The integer ratios for binary floats can have numerators or +    # denominators with over 300 decimal digits.  The problem is more +    # acute with decimal floats where the the default decimal context +    # supports a huge range of exponents from Emin=-999999 to +    # Emax=999999.  When expanded with as_integer_ratio(), numbers like +    # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large +    # numerators or denominators that will slow computation. + +    # When the integer ratios are accumulated as fractions, the size +    # grows to cover the full range from the smallest magnitude to the +    # largest.  For example, Fraction(3.14E+300) + Fraction(3.14E-300), +    # has a 616 digit numerator.  Likewise, +    # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000')) +    # has 10,003 digit numerator. + +    # This doesn't seem to have been problem in practice, but it is a +    # potential pitfall. +      try:          return x.as_integer_ratio()      except AttributeError: @@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):              raise StatisticsError(errmsg)          yield x -def _isqrt_frac_rto(n: int, m: int) -> float: + +def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:      """Square root of n/m, rounded to the nearest integer using round-to-odd."""      # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf      a = math.isqrt(n // m)      return a | (a*a*m != n) -# For 53 bit precision floats, the _sqrt_frac() shift is 109. -_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 -def _sqrt_frac(n: int, m: int) -> float: +# For 53 bit precision floats, the bit width used in +# _float_sqrt_of_frac() is 109. +_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3 + + +def _float_sqrt_of_frac(n: int, m: int) -> float:      """Square root of n/m as a float, correctly rounded."""      # See principle and proof sketch at: https://bugs.python.org/msg407078 -    q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 +    q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2      if q >= 0: -        numerator = _isqrt_frac_rto(n, m << 2 * q) << q +        numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q          denominator = 1      else: -        numerator = _isqrt_frac_rto(n << -2 * q, m) +        numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)          denominator = 1 << -q      return numerator / denominator   # Convert to float +def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal: +    """Square root of n/m as a Decimal, correctly rounded.""" +    # Premise:  For decimal, computing (n/m).sqrt() can be off +    #           by 1 ulp from the correctly rounded result. +    # Method:   Check the result, moving up or down a step if needed. +    if n <= 0: +        if not n: +            return Decimal('0.0') +        n, m = -n, -m + +    root = (Decimal(n) / Decimal(m)).sqrt() +    nr, dr = root.as_integer_ratio() + +    plus = root.next_plus() +    np, dp = plus.as_integer_ratio() +    # test: n / m > ((root + plus) / 2) ** 2 +    if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2: +        return plus + +    minus = root.next_minus() +    nm, dm = minus.as_integer_ratio() +    # test: n / m < ((root + minus) / 2) ** 2 +    if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2: +        return minus + +    return root + +  # === Measures of central tendency (averages) ===  def mean(data): @@ -869,7 +923,7 @@ def stdev(data, xbar=None):      if hasattr(T, 'sqrt'):          var = _convert(mss, T)          return var.sqrt() -    return _sqrt_frac(mss.numerator, mss.denominator) +    return _float_sqrt_of_frac(mss.numerator, mss.denominator)  def pstdev(data, mu=None): @@ -888,10 +942,9 @@ def pstdev(data, mu=None):          raise StatisticsError('pstdev requires at least one data point')      T, ss = _ss(data, mu)      mss = ss / n -    if hasattr(T, 'sqrt'): -        var = _convert(mss, T) -        return var.sqrt() -    return _sqrt_frac(mss.numerator, mss.denominator) +    if issubclass(T, Decimal): +        return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) +    return _float_sqrt_of_frac(mss.numerator, mss.denominator)  # === Statistics for relations between two inputs === | 
