diff options
author | Hernan <hernan.grecco@gmail.com> | 2022-02-14 00:41:09 -0300 |
---|---|---|
committer | Hernan <hernan.grecco@gmail.com> | 2022-02-14 00:49:41 -0300 |
commit | 2f90d96694dbb2cc7eee81e36639f19741920328 (patch) | |
tree | c71054a2c0ec4dc1a8ad1287809de72b30bb7564 /pint/testsuite/test_dask.py | |
parent | 5403f46ecf636d0749cf54cddf725d177d60af61 (diff) | |
download | pint-2f90d96694dbb2cc7eee81e36639f19741920328.tar.gz |
Update testsuite to avoid a complete fail when the UnitRegistry is faulty
Under no circunstances a registry should be instantiated in a
module outside a fixture to avoid error during collection.
This precludes running simple tests that do not depend on
the registry.
Diffstat (limited to 'pint/testsuite/test_dask.py')
-rw-r--r-- | pint/testsuite/test_dask.py | 79 |
1 files changed, 44 insertions, 35 deletions
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index cb4d2c0..ef6044c 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -16,12 +16,17 @@ from distributed.utils_test import cluster, gen_cluster, loop # isort:skip loop = loop # flake8 -ureg = UnitRegistry(force_ndarray_like=True) units_ = "kilogram" -def add_five(q): - return q + 5 * ureg(units_) +@pytest.fixture(scope="module") +def local_registry(): + # Set up unit registry and sample + return UnitRegistry(force_ndarray_like=True) + + +def add_five(local_registry, q): + return q + 5 * local_registry(units_) @pytest.fixture @@ -34,21 +39,21 @@ def numpy_array(): return np.arange(0, 25, dtype=float).reshape((5, 5)) + 5 -def test_is_dask_collection(dask_array): +def test_is_dask_collection(local_registry, dask_array): """Test that a pint.Quantity wrapped Dask array is a Dask collection.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) assert dask.is_dask_collection(q) -def test_is_not_dask_collection(numpy_array): +def test_is_not_dask_collection(local_registry, numpy_array): """Test that other pint.Quantity wrapped objects are not Dask collections.""" - q = ureg.Quantity(numpy_array, units_) + q = local_registry.Quantity(numpy_array, units_) assert not dask.is_dask_collection(q) -def test_dask_scheduler(dask_array): +def test_dask_scheduler(local_registry, dask_array): """Test that a pint.Quantity wrapped Dask array has the correct default scheduler.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) scheduler = q.__dask_scheduler__ scheduler_name = f"{scheduler.__module__}.{scheduler.__name__}" @@ -70,27 +75,27 @@ def test_dask_scheduler(dask_array): ), ), ) -def test_dask_tokenize(item): +def test_dask_tokenize(local_registry, item): """Test that a pint.Quantity wrapping something has a unique token.""" dask_token = dask.base.tokenize(item) - q = ureg.Quantity(item, units_) + q = local_registry.Quantity(item, units_) assert dask.base.tokenize(item) != dask.base.tokenize(q) assert dask.base.tokenize(item) == dask_token -def test_dask_optimize(dask_array): +def test_dask_optimize(local_registry, dask_array): """Test that a pint.Quantity wrapped Dask array can be optimized.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) assert q.__dask_optimize__ == dask.array.Array.__dask_optimize__ -def test_compute(dask_array, numpy_array): +def test_compute(local_registry, dask_array, numpy_array): """Test the compute() method on a pint.Quantity wrapped Dask array.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) - comps = add_five(q) + comps = add_five(local_registry, q) res = comps.compute() assert np.all(res.m == numpy_array) @@ -99,11 +104,11 @@ def test_compute(dask_array, numpy_array): assert q.magnitude is dask_array -def test_persist(dask_array, numpy_array): +def test_persist(local_registry, dask_array, numpy_array): """Test the persist() method on a pint.Quantity wrapped Dask array.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) - comps = add_five(q) + comps = add_five(local_registry, q) res = comps.persist() assert np.all(res.m == numpy_array) @@ -115,11 +120,11 @@ def test_persist(dask_array, numpy_array): @pytest.mark.skipif( importlib.util.find_spec("graphviz") is None, reason="GraphViz is not available" ) -def test_visualize(dask_array): +def test_visualize(local_registry, dask_array): """Test the visualize() method on a pint.Quantity wrapped Dask array.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) - comps = add_five(q) + comps = add_five(local_registry, q) res = comps.visualize() assert res is None @@ -128,11 +133,11 @@ def test_visualize(dask_array): os.remove("mydask.png") -def test_compute_persist_equivalent(dask_array, numpy_array): +def test_compute_persist_equivalent(local_registry, dask_array, numpy_array): """Test that compute() and persist() return the same numeric results.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) - comps = add_five(q) + comps = add_five(local_registry, q) res_compute = comps.compute() res_persist = comps.persist() @@ -141,11 +146,11 @@ def test_compute_persist_equivalent(dask_array, numpy_array): @pytest.mark.parametrize("method", ["compute", "persist", "visualize"]) -def test_exception_method_not_implemented(numpy_array, method): +def test_exception_method_not_implemented(local_registry, numpy_array, method): """Test exception handling for convenience methods on a pint.Quantity wrapped object that is not a dask.array.Array object. """ - q = ureg.Quantity(numpy_array, units_) + q = local_registry.Quantity(numpy_array, units_) exctruth = ( f"Method {method} only implemented for objects of" @@ -157,13 +162,13 @@ def test_exception_method_not_implemented(numpy_array, method): obj_method() -def test_distributed_compute(loop, dask_array, numpy_array): +def test_distributed_compute(local_registry, loop, dask_array, numpy_array): """Test compute() for distributed machines.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): - comps = add_five(q) + comps = add_five(local_registry, q) res = comps.compute() assert np.all(res.m == numpy_array) @@ -173,13 +178,13 @@ def test_distributed_compute(loop, dask_array, numpy_array): assert q.magnitude is dask_array -def test_distributed_persist(loop, dask_array): +def test_distributed_persist(local_registry, loop, dask_array): """Test persist() for distributed machines.""" - q = ureg.Quantity(dask_array, units_) + q = local_registry.Quantity(dask_array, units_) with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): - comps = add_five(q) + comps = add_five(local_registry, q) persisted_q = comps.persist() comps_truth = dask_array + 5 @@ -195,10 +200,14 @@ def test_distributed_persist(loop, dask_array): @gen_cluster(client=True) async def test_async(c, s, a, b): """Test asynchronous operations.""" + + # TODO: use a fixture for this. + local_registry = UnitRegistry(force_ndarray_like=True) + da = dask.array.arange(0, 25, chunks=5, dtype=float).reshape((5, 5)) - q = ureg.Quantity(da, units_) + q = local_registry.Quantity(da, units_) - x = q + ureg.Quantity(5, units_) + x = q + local_registry.Quantity(5, units_) y = x.persist() assert str(y) |