summaryrefslogtreecommitdiff
path: root/pint/testsuite/test_dask.py
diff options
context:
space:
mode:
authorHernan <hernan.grecco@gmail.com>2022-02-14 00:41:09 -0300
committerHernan <hernan.grecco@gmail.com>2022-02-14 00:49:41 -0300
commit2f90d96694dbb2cc7eee81e36639f19741920328 (patch)
treec71054a2c0ec4dc1a8ad1287809de72b30bb7564 /pint/testsuite/test_dask.py
parent5403f46ecf636d0749cf54cddf725d177d60af61 (diff)
downloadpint-2f90d96694dbb2cc7eee81e36639f19741920328.tar.gz
Update testsuite to avoid a complete fail when the UnitRegistry is faulty
Under no circunstances a registry should be instantiated in a module outside a fixture to avoid error during collection. This precludes running simple tests that do not depend on the registry.
Diffstat (limited to 'pint/testsuite/test_dask.py')
-rw-r--r--pint/testsuite/test_dask.py79
1 files changed, 44 insertions, 35 deletions
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index cb4d2c0..ef6044c 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -16,12 +16,17 @@ from distributed.utils_test import cluster, gen_cluster, loop # isort:skip
loop = loop # flake8
-ureg = UnitRegistry(force_ndarray_like=True)
units_ = "kilogram"
-def add_five(q):
- return q + 5 * ureg(units_)
+@pytest.fixture(scope="module")
+def local_registry():
+ # Set up unit registry and sample
+ return UnitRegistry(force_ndarray_like=True)
+
+
+def add_five(local_registry, q):
+ return q + 5 * local_registry(units_)
@pytest.fixture
@@ -34,21 +39,21 @@ def numpy_array():
return np.arange(0, 25, dtype=float).reshape((5, 5)) + 5
-def test_is_dask_collection(dask_array):
+def test_is_dask_collection(local_registry, dask_array):
"""Test that a pint.Quantity wrapped Dask array is a Dask collection."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
assert dask.is_dask_collection(q)
-def test_is_not_dask_collection(numpy_array):
+def test_is_not_dask_collection(local_registry, numpy_array):
"""Test that other pint.Quantity wrapped objects are not Dask collections."""
- q = ureg.Quantity(numpy_array, units_)
+ q = local_registry.Quantity(numpy_array, units_)
assert not dask.is_dask_collection(q)
-def test_dask_scheduler(dask_array):
+def test_dask_scheduler(local_registry, dask_array):
"""Test that a pint.Quantity wrapped Dask array has the correct default scheduler."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
scheduler = q.__dask_scheduler__
scheduler_name = f"{scheduler.__module__}.{scheduler.__name__}"
@@ -70,27 +75,27 @@ def test_dask_scheduler(dask_array):
),
),
)
-def test_dask_tokenize(item):
+def test_dask_tokenize(local_registry, item):
"""Test that a pint.Quantity wrapping something has a unique token."""
dask_token = dask.base.tokenize(item)
- q = ureg.Quantity(item, units_)
+ q = local_registry.Quantity(item, units_)
assert dask.base.tokenize(item) != dask.base.tokenize(q)
assert dask.base.tokenize(item) == dask_token
-def test_dask_optimize(dask_array):
+def test_dask_optimize(local_registry, dask_array):
"""Test that a pint.Quantity wrapped Dask array can be optimized."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
assert q.__dask_optimize__ == dask.array.Array.__dask_optimize__
-def test_compute(dask_array, numpy_array):
+def test_compute(local_registry, dask_array, numpy_array):
"""Test the compute() method on a pint.Quantity wrapped Dask array."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
- comps = add_five(q)
+ comps = add_five(local_registry, q)
res = comps.compute()
assert np.all(res.m == numpy_array)
@@ -99,11 +104,11 @@ def test_compute(dask_array, numpy_array):
assert q.magnitude is dask_array
-def test_persist(dask_array, numpy_array):
+def test_persist(local_registry, dask_array, numpy_array):
"""Test the persist() method on a pint.Quantity wrapped Dask array."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
- comps = add_five(q)
+ comps = add_five(local_registry, q)
res = comps.persist()
assert np.all(res.m == numpy_array)
@@ -115,11 +120,11 @@ def test_persist(dask_array, numpy_array):
@pytest.mark.skipif(
importlib.util.find_spec("graphviz") is None, reason="GraphViz is not available"
)
-def test_visualize(dask_array):
+def test_visualize(local_registry, dask_array):
"""Test the visualize() method on a pint.Quantity wrapped Dask array."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
- comps = add_five(q)
+ comps = add_five(local_registry, q)
res = comps.visualize()
assert res is None
@@ -128,11 +133,11 @@ def test_visualize(dask_array):
os.remove("mydask.png")
-def test_compute_persist_equivalent(dask_array, numpy_array):
+def test_compute_persist_equivalent(local_registry, dask_array, numpy_array):
"""Test that compute() and persist() return the same numeric results."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
- comps = add_five(q)
+ comps = add_five(local_registry, q)
res_compute = comps.compute()
res_persist = comps.persist()
@@ -141,11 +146,11 @@ def test_compute_persist_equivalent(dask_array, numpy_array):
@pytest.mark.parametrize("method", ["compute", "persist", "visualize"])
-def test_exception_method_not_implemented(numpy_array, method):
+def test_exception_method_not_implemented(local_registry, numpy_array, method):
"""Test exception handling for convenience methods on a pint.Quantity wrapped
object that is not a dask.array.Array object.
"""
- q = ureg.Quantity(numpy_array, units_)
+ q = local_registry.Quantity(numpy_array, units_)
exctruth = (
f"Method {method} only implemented for objects of"
@@ -157,13 +162,13 @@ def test_exception_method_not_implemented(numpy_array, method):
obj_method()
-def test_distributed_compute(loop, dask_array, numpy_array):
+def test_distributed_compute(local_registry, loop, dask_array, numpy_array):
"""Test compute() for distributed machines."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop):
- comps = add_five(q)
+ comps = add_five(local_registry, q)
res = comps.compute()
assert np.all(res.m == numpy_array)
@@ -173,13 +178,13 @@ def test_distributed_compute(loop, dask_array, numpy_array):
assert q.magnitude is dask_array
-def test_distributed_persist(loop, dask_array):
+def test_distributed_persist(local_registry, loop, dask_array):
"""Test persist() for distributed machines."""
- q = ureg.Quantity(dask_array, units_)
+ q = local_registry.Quantity(dask_array, units_)
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop):
- comps = add_five(q)
+ comps = add_five(local_registry, q)
persisted_q = comps.persist()
comps_truth = dask_array + 5
@@ -195,10 +200,14 @@ def test_distributed_persist(loop, dask_array):
@gen_cluster(client=True)
async def test_async(c, s, a, b):
"""Test asynchronous operations."""
+
+ # TODO: use a fixture for this.
+ local_registry = UnitRegistry(force_ndarray_like=True)
+
da = dask.array.arange(0, 25, chunks=5, dtype=float).reshape((5, 5))
- q = ureg.Quantity(da, units_)
+ q = local_registry.Quantity(da, units_)
- x = q + ureg.Quantity(5, units_)
+ x = q + local_registry.Quantity(5, units_)
y = x.persist()
assert str(y)