summaryrefslogtreecommitdiff
path: root/pint/matplotlib.py
blob: ea88c704649fd12cdec8f814e3233d622bc3d4f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
    pint.matplotlib
    ~~~~~~~~~~~~~~~

    Functions and classes related to working with Matplotlib's support
    for plotting with units.

    :copyright: 2017 by Pint Authors, see AUTHORS for more details.
    :license: BSD, see LICENSE for more details.
"""

from __future__ import annotations

import matplotlib.units

from .util import iterable, sized


class PintAxisInfo(matplotlib.units.AxisInfo):
    """Axis information whose default label is the formatted unit."""

    def __init__(self, units):
        """Build axis info labelled with ``units`` as rendered by its registry.

        The label text is produced by the ``mpl_formatter`` attached to the
        unit's registry (``units._REGISTRY``).
        """
        label_text = units._REGISTRY.mpl_formatter.format(units)
        super().__init__(label=label_text)


class PintConverter(matplotlib.units.ConversionInterface):
    """Implement support for pint within matplotlib's unit conversion framework.

    Instances are meant to be registered in ``matplotlib.units.registry``
    keyed by a registry's ``Quantity`` class, so matplotlib consults them
    whenever such quantities are plotted.
    """

    def __init__(self, registry):
        super().__init__()
        # Registry used to build quantities from bare (unit-less) values.
        self._reg = registry

    def convert(self, value, unit, axis):
        """Convert :class:`Quantity` instances to plain magnitudes for matplotlib.

        Parameters
        ----------
        value
            A quantity (scalar or array-valued), an iterable of quantities
            and/or bare numbers, or a single bare number.
        unit
            Target unit in which the magnitudes are expressed.
        axis
            The matplotlib axis; its units are assumed for bare numbers.

        Returns
        -------
        The magnitude(s) of ``value`` in ``unit``: the converted magnitude
        for a quantity, or a list of magnitudes for a plain iterable.
        """
        # Short circuit: a quantity (including an array-valued one) is
        # converted in one vectorised call rather than element by element.
        if hasattr(value, "units"):
            return value.to(unit).magnitude
        if iterable(value):
            return [self._convert_value(v, unit, axis) for v in value]
        return self._convert_value(value, unit, axis)

    def _convert_value(self, value, unit, axis):
        """Handle converting using attached unit or falling back to axis units."""
        if hasattr(value, "units"):
            return value.to(unit).magnitude
        else:
            # Bare number: interpret it in the axis' current units.
            return self._reg.Quantity(value, axis.get_units()).to(unit).magnitude

    @staticmethod
    def axisinfo(unit, axis):
        """Return axis information for this particular unit.

        Returns ``None`` when no unit is known yet. Matplotlib accepts that
        as "no axis info", whereas building a label from ``None`` would fail.
        """
        if unit is None:
            return None
        return PintAxisInfo(unit)

    @staticmethod
    def default_units(x, axis):
        """Get the default unit to use for the given combination of unit and axis.

        For a non-empty sized iterable the units of its first element are
        used. Otherwise the units attached to ``x`` itself are used, if any.
        Returns ``None`` when no units can be found.
        """
        if iterable(x) and sized(x) and len(x) > 0:
            return getattr(x[0], "units", None)
        # Scalars and empty containers: there is no first element to inspect,
        # so fall back to units carried by the container/value itself.
        return getattr(x, "units", None)


def setup_matplotlib_handlers(registry, enable):
    """Set up matplotlib's unit support to handle units from a registry.

    Parameters
    ----------
    registry : pint.UnitRegistry
        The registry that will be used.
    enable : bool
        Whether support should be enabled or disabled.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If the installed Matplotlib is older than 2.0.
    """
    import re

    # Compare the major version numerically. A plain string comparison
    # ("10.0" < "2.0" is True) would wrongly reject future major releases.
    major = re.match(r"\d+", matplotlib.__version__)
    if major is not None and int(major.group()) < 2:
        raise RuntimeError("Matplotlib >= 2.0 required to work with pint.")

    if enable:
        matplotlib.units.registry[registry.Quantity] = PintConverter(registry)
    else:
        # Disabling is a no-op when no converter was registered.
        matplotlib.units.registry.pop(registry.Quantity, None)