diff options
Diffstat (limited to 'pint/testsuite/test_matplotlib.py')
-rw-r--r-- | pint/testsuite/test_matplotlib.py | 35 |
1 files changed, 20 insertions, 15 deletions
diff --git a/pint/testsuite/test_matplotlib.py b/pint/testsuite/test_matplotlib.py index 3c590ae..25f3172 100644 --- a/pint/testsuite/test_matplotlib.py +++ b/pint/testsuite/test_matplotlib.py @@ -6,38 +6,43 @@ from pint import UnitRegistry plt = pytest.importorskip("matplotlib.pyplot", reason="matplotlib is not available") np = pytest.importorskip("numpy", reason="NumPy is not available") -# Set up unit registry for matplotlib -ureg = UnitRegistry() -ureg.setup_matplotlib(True) + +@pytest.fixture(scope="module") +def local_registry(): + # Set up unit registry for matplotlib + ureg = UnitRegistry() + ureg.setup_matplotlib(True) + return ureg + # Set up matplotlib plt.switch_backend("agg") @pytest.mark.mpl_image_compare(tolerance=0, remove_text=True) -def test_basic_plot(): - y = np.linspace(0, 30) * ureg.miles - x = np.linspace(0, 5) * ureg.hours +def test_basic_plot(local_registry): + y = np.linspace(0, 30) * local_registry.miles + x = np.linspace(0, 5) * local_registry.hours fig, ax = plt.subplots() ax.plot(x, y, "tab:blue") - ax.axhline(26400 * ureg.feet, color="tab:red") - ax.axvline(120 * ureg.minutes, color="tab:green") + ax.axhline(26400 * local_registry.feet, color="tab:red") + ax.axvline(120 * local_registry.minutes, color="tab:green") return fig @pytest.mark.mpl_image_compare(tolerance=0, remove_text=True) -def test_plot_with_set_units(): - y = np.linspace(0, 30) * ureg.miles - x = np.linspace(0, 5) * ureg.hours +def test_plot_with_set_units(local_registry): + y = np.linspace(0, 30) * local_registry.miles + x = np.linspace(0, 5) * local_registry.hours fig, ax = plt.subplots() - ax.yaxis.set_units(ureg.inches) - ax.xaxis.set_units(ureg.seconds) + ax.yaxis.set_units(local_registry.inches) + ax.xaxis.set_units(local_registry.seconds) ax.plot(x, y, "tab:blue") - ax.axhline(26400 * ureg.feet, color="tab:red") - ax.axvline(120 * ureg.minutes, color="tab:green") + ax.axhline(26400 * local_registry.feet, color="tab:red") + ax.axvline(120 * local_registry.minutes, color="tab:green") return fig |