summaryrefslogtreecommitdiff
path: root/pint/testing.py
blob: 126a39fc8553ac1c9ae03f7eea0210eb622da0bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

import cmath
import math
import warnings
from numbers import Number
from typing import Optional

from . import Quantity
from .compat import ndarray

try:
    import numpy as np
except ImportError:
    np = None


def _get_comparable_magnitudes(first, second, msg):
    """Reduce two operands to a pair of directly comparable magnitudes.

    * Two Quantities: ``second`` is converted to the units of ``first`` when
      they are compatible under the registry's active contexts. The units
      must then be identical.
    * One Quantity and one plain value: the Quantity must be dimensionless.
      It is converted to dimensionless units and its magnitude is used.
    * Two plain values: they are returned unchanged.

    ``msg`` is used as the prefix of any assertion failure message.
    """
    first_is_quantity = isinstance(first, Quantity)
    second_is_quantity = isinstance(second, Quantity)

    if first_is_quantity and second_is_quantity:
        # Honour any contexts the user enabled (e.g. spectroscopy) when
        # deciding whether the two quantities are convertible.
        active_contexts = first._REGISTRY._active_ctx.contexts
        if first.is_compatible_with(second, *active_contexts):
            second = second.to(first)
        assert first.units == second.units, msg + " Units are not equal."
        return first.magnitude, second.magnitude

    if first_is_quantity:
        assert first.dimensionless, msg + " The first is not dimensionless."
        return first.to("").magnitude, second

    if second_is_quantity:
        assert second.dimensionless, msg + " The second is not dimensionless."
        return first, second.to("").magnitude

    return first, second


def assert_equal(first, second, msg: Optional[str] = None) -> None:
    """Assert that two values (Quantities or plain magnitudes) are equal.

    Both operands are first reduced to comparable magnitudes by
    ``_get_comparable_magnitudes``. When both are Quantities, ``second`` is
    converted to the units of ``first``. Otherwise any Quantity operand must
    be dimensionless.

    Arrays are compared element-wise with
    ``numpy.testing.assert_array_equal``. Scalar numbers (real or complex)
    are compared with ``==``, except that NaN is treated as equal to NaN.

    Parameters
    ----------
    first, second
        Values to compare.
    msg
        Custom assertion message. When omitted, one is built from the reprs
        of both operands.

    Raises
    ------
    AssertionError
        If the units or the magnitudes differ.

    Warns
    -----
    UserWarning
        If a non-array magnitude is not a ``numbers.Number``. The comparison
        is then skipped rather than failed.
    """
    if msg is None:
        msg = f"Comparing {first!r} and {second!r}. "

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += f" (Converted to {m1!r} and {m2!r}): Magnitudes are not equal"

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_array_equal(m1, m2, err_msg=msg)
    elif not isinstance(m1, Number):
        # Best effort: values we cannot compare are reported, not failed.
        # The category stays UserWarning so that existing warning filters
        # keep matching. stacklevel=2 points the warning at the caller.
        warnings.warn(
            f"In assert_equal, the first magnitude is not a number "
            f"({type(m1).__name__}); comparison skipped.",
            UserWarning,
            stacklevel=2,
        )
        return
    elif not isinstance(m2, Number):
        warnings.warn(
            f"In assert_equal, the second magnitude is not a number "
            f"({type(m2).__name__}); comparison skipped.",
            UserWarning,
            stacklevel=2,
        )
        return
    elif cmath.isnan(m1) or cmath.isnan(m2):
        # NaN != NaN under ==, so NaN is matched explicitly. cmath (unlike
        # math) also accepts complex magnitudes.
        assert cmath.isnan(m1) and cmath.isnan(m2), msg
    else:
        assert m1 == m2, msg


def assert_allclose(
    first, second, rtol: float = 1e-07, atol: float = 0, msg: Optional[str] = None
) -> None:
    """Assert that two values (Quantities or plain magnitudes) are close.

    Both operands are first reduced to comparable magnitudes by
    ``_get_comparable_magnitudes``. Arrays are delegated to
    ``numpy.testing.assert_allclose``. Scalar numbers (real or complex) pass
    when::

        abs(m1 - m2) <= max(rtol * max(abs(m1), abs(m2)), atol)

    NumPy uses ``atol + rtol * abs(desired)`` instead. Unlike NumPy's rule,
    this one is symmetric in its two arguments. NaN matches only NaN, and an
    infinity matches only an identical infinity (``inf`` does not match
    ``-inf``).

    Parameters
    ----------
    first, second
        Values to compare.
    rtol
        Relative tolerance.
    atol
        Absolute tolerance.
    msg
        Custom assertion message. When omitted, one is built from the
        operands' reprs. If repr formatting raises ``TypeError``, ``str`` is
        tried next, and then a fixed prefix.

    Raises
    ------
    AssertionError
        If the units differ or the magnitudes are not close.

    Warns
    -----
    UserWarning
        If a non-array magnitude is not a ``numbers.Number``. The comparison
        is then skipped rather than failed.
    """
    if msg is None:
        try:
            msg = f"Comparing {first!r} and {second!r}. "
        except TypeError:
            try:
                msg = f"Comparing {first} and {second}. "
            except Exception:
                msg = "Comparing"

    m1, m2 = _get_comparable_magnitudes(first, second, msg)
    msg += f" (Converted to {m1!r} and {m2!r})"

    if isinstance(m1, ndarray) or isinstance(m2, ndarray):
        np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg)
    elif not isinstance(m1, Number):
        # Best effort: values we cannot compare are reported, not failed.
        # The category stays UserWarning so that existing warning filters
        # keep matching. stacklevel=2 points the warning at the caller.
        warnings.warn(
            f"In assert_allclose, the first magnitude is not a number "
            f"({type(m1).__name__}); comparison skipped.",
            UserWarning,
            stacklevel=2,
        )
        return
    elif not isinstance(m2, Number):
        warnings.warn(
            f"In assert_allclose, the second magnitude is not a number "
            f"({type(m2).__name__}); comparison skipped.",
            UserWarning,
            stacklevel=2,
        )
        return
    elif cmath.isnan(m1) or cmath.isnan(m2):
        # NaN never satisfies the tolerance test, so match it explicitly.
        # cmath (unlike math) also accepts complex magnitudes.
        assert cmath.isnan(m1) and cmath.isnan(m2), msg
    elif cmath.isinf(m1) or cmath.isinf(m2):
        # inf - inf is NaN, so the tolerance test cannot be used here. An
        # infinity must equal the other value exactly, sign included.
        assert m1 == m2, msg
    else:
        assert abs(m1 - m2) <= max(rtol * max(abs(m1), abs(m2)), atol), msg