summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorHernan Grecco <hernan.grecco@gmail.com>2023-04-24 11:21:55 -0300
committerGitHub <noreply@github.com>2023-04-24 11:21:55 -0300
commit4357104a645e90b937a54318f4b5574f7282129f (patch)
tree9217e58390c4e5a88f3beb96995771fc6c8c48b6 /pint
parent62bb350c75d4d8fa61a490eb82c8770c2a09e37b (diff)
parent5cbebb618ea8364902512d0bb352d98d149da161 (diff)
downloadpint-4357104a645e90b937a54318f4b5574f7282129f.tar.gz
Merge branch 'master' into improve-latex-escaping
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/dask/__init__.py14
-rw-r--r--pint/facets/numpy/numpy_func.py20
-rw-r--r--pint/facets/plain/registry.py4
-rw-r--r--pint/matplotlib.py3
-rw-r--r--pint/testsuite/baseline/test_basic_plot.pngbin17483 -> 17415 bytes
-rw-r--r--pint/testsuite/baseline/test_plot_with_non_default_format.pngbin0 -> 16617 bytes
-rw-r--r--pint/testsuite/baseline/test_plot_with_set_units.pngbin18145 -> 18176 bytes
-rw-r--r--pint/testsuite/test_dask.py2
-rw-r--r--pint/testsuite/test_issues.py16
-rw-r--r--pint/testsuite/test_matplotlib.py18
-rw-r--r--pint/testsuite/test_numpy.py32
-rw-r--r--pint/testsuite/test_pint_eval.py74
-rw-r--r--pint/util.py4
13 files changed, 158 insertions, 29 deletions
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index f99e8a2..5276d3c 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -46,10 +46,7 @@ class DaskQuantity:
def __dask_tokenize__(self):
from dask.base import tokenize
- from pint import UnitRegistry
-
- # TODO: Check if this is the right class as first argument
- return (UnitRegistry.Quantity, tokenize(self._magnitude), self.units)
+ return (type(self), tokenize(self._magnitude), self.units)
@property
def __dask_optimize__(self):
@@ -67,14 +64,9 @@ class DaskQuantity:
func, args = self._magnitude.__dask_postpersist__()
return self._dask_finalize, (func, args, self.units)
- @staticmethod
- def _dask_finalize(results, func, args, units):
+ def _dask_finalize(self, results, func, args, units):
values = func(results, *args)
-
- from pint import Quantity
-
- # TODO: Check if this is the right class as first argument
- return Quantity(values, units)
+ return type(self)(values, units)
@check_dask_array
def compute(self, **kwargs):
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 7bce41e..138ed24 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -527,22 +527,16 @@ def _meshgrid(*xi, **kwargs):
@implements("full_like", "function")
-def _full_like(a, fill_value, dtype=None, order="K", subok=True, shape=None):
+def _full_like(a, fill_value, **kwargs):
# Make full_like by multiplying with array from ones_like in a
# non-multiplicative-unit-safe way
if hasattr(fill_value, "_REGISTRY"):
return fill_value._REGISTRY.Quantity(
- (
- np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
- * fill_value.m
- ),
+ np.ones_like(a, **kwargs) * fill_value.m,
fill_value.units,
)
else:
- return (
- np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
- * fill_value
- )
+ return np.ones_like(a, **kwargs) * fill_value
@implements("interp", "function")
@@ -796,6 +790,7 @@ for func_str, unit_arguments, wrap_output in [
("ptp", "a", True),
("ravel", "a", True),
("round_", "a", True),
+ ("round", "a", True),
("sort", "a", True),
("median", "a", True),
("nanmedian", "a", True),
@@ -816,8 +811,10 @@ for func_str, unit_arguments, wrap_output in [
("broadcast_to", ["array"], True),
("amax", ["a", "initial"], True),
("amin", ["a", "initial"], True),
+ ("max", ["a", "initial"], True),
+ ("min", ["a", "initial"], True),
("searchsorted", ["a", "v"], False),
- ("isclose", ["a", "b"], False),
+ ("isclose", ["a", "b", "atol"], False),
("nan_to_num", ["x", "nan", "posinf", "neginf"], True),
("clip", ["a", "a_min", "a_max"], True),
("append", ["arr", "values"], True),
@@ -827,9 +824,10 @@ for func_str, unit_arguments, wrap_output in [
("lib.stride_tricks.sliding_window_view", "x", True),
("rot90", "m", True),
("insert", ["arr", "values"], True),
+ ("delete", ["arr"], True),
("resize", "a", True),
("reshape", "a", True),
- ("allclose", ["a", "b"], False),
+ ("allclose", ["a", "b", "atol"], False),
("intersect1d", ["ar1", "ar2"], True),
]:
implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output)
diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py
index ffa6fb4..eed73e1 100644
--- a/pint/facets/plain/registry.py
+++ b/pint/facets/plain/registry.py
@@ -204,6 +204,7 @@ class PlainRegistry(metaclass=RegistryMeta):
case_sensitive: bool = True,
cache_folder: Union[str, pathlib.Path, None] = None,
separate_format_defaults: Optional[bool] = None,
+ mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
self._adders = dict()
@@ -244,6 +245,9 @@ class PlainRegistry(metaclass=RegistryMeta):
#: Default locale identifier string, used when calling format_babel without explicit locale.
self.set_fmt_locale(fmt_locale)
+ #: sets the formatter used when plotting with matplotlib
+ self.mpl_formatter = mpl_formatter
+
#: Numerical type used for non integer values.
self._non_int_type = non_int_type
diff --git a/pint/matplotlib.py b/pint/matplotlib.py
index 3785c7d..ea88c70 100644
--- a/pint/matplotlib.py
+++ b/pint/matplotlib.py
@@ -21,7 +21,8 @@ class PintAxisInfo(matplotlib.units.AxisInfo):
def __init__(self, units):
"""Set the default label to the pretty-print of the unit."""
- super().__init__(label="{:P}".format(units))
+ formatter = units._REGISTRY.mpl_formatter
+ super().__init__(label=formatter.format(units))
class PintConverter(matplotlib.units.ConversionInterface):
diff --git a/pint/testsuite/baseline/test_basic_plot.png b/pint/testsuite/baseline/test_basic_plot.png
index 63be609..b0c4d18 100644
--- a/pint/testsuite/baseline/test_basic_plot.png
+++ b/pint/testsuite/baseline/test_basic_plot.png
Binary files differ
diff --git a/pint/testsuite/baseline/test_plot_with_non_default_format.png b/pint/testsuite/baseline/test_plot_with_non_default_format.png
new file mode 100644
index 0000000..1cb5b18
--- /dev/null
+++ b/pint/testsuite/baseline/test_plot_with_non_default_format.png
Binary files differ
diff --git a/pint/testsuite/baseline/test_plot_with_set_units.png b/pint/testsuite/baseline/test_plot_with_set_units.png
index 5fd3ce0..a59924c 100644
--- a/pint/testsuite/baseline/test_plot_with_set_units.png
+++ b/pint/testsuite/baseline/test_plot_with_set_units.png
Binary files differ
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index 69c80fe..f4dee6a 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -149,6 +149,8 @@ def test_compute_persist_equivalent(local_registry, dask_array, numpy_array):
assert np.all(res_compute == res_persist)
assert res_compute.units == res_persist.units == units_
+ assert type(res_compute) == local_registry.Quantity
+ assert type(res_persist) == local_registry.Quantity
@pytest.mark.parametrize("method", ["compute", "persist", "visualize"])
diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py
index cf7e39c..0c1155c 100644
--- a/pint/testsuite/test_issues.py
+++ b/pint/testsuite/test_issues.py
@@ -1,4 +1,5 @@
import copy
+import decimal
import math
import pprint
@@ -1040,6 +1041,21 @@ def test_backcompat_speed_velocity(func_registry):
assert get("[speed]") == UnitsContainer({"[length]": 1, "[time]": -1})
+def test_issue1527():
+ ureg = UnitRegistry(non_int_type=decimal.Decimal)
+ x = ureg.parse_expression("2 microliter milligram/liter")
+ assert x.magnitude.as_tuple()[1] == (2,)
+ assert x.to_compact().as_tuple()[1] == (2,)
+ assert x.to_base_units().as_tuple()[1] == (2,)
+ assert x.to("ng").as_tuple()[1] == (2,)
+
+
+def test_issue1621():
+ ureg = UnitRegistry(non_int_type=decimal.Decimal)
+ digits = ureg.Quantity("5.0 mV/m").to_base_units().magnitude.as_tuple()[1]
+ assert digits == (5, 0)
+
+
def test_issue1631():
import pint
diff --git a/pint/testsuite/test_matplotlib.py b/pint/testsuite/test_matplotlib.py
index 25f3172..0735721 100644
--- a/pint/testsuite/test_matplotlib.py
+++ b/pint/testsuite/test_matplotlib.py
@@ -46,3 +46,21 @@ def test_plot_with_set_units(local_registry):
ax.axvline(120 * local_registry.minutes, color="tab:green")
return fig
+
+
+@pytest.mark.mpl_image_compare(tolerance=0, remove_text=True)
+def test_plot_with_non_default_format(local_registry):
+ local_registry.mpl_formatter = "{:~P}"
+
+ y = np.linspace(0, 30) * local_registry.miles
+ x = np.linspace(0, 5) * local_registry.hours
+
+ fig, ax = plt.subplots()
+ ax.yaxis.set_units(local_registry.inches)
+ ax.xaxis.set_units(local_registry.seconds)
+
+ ax.plot(x, y, "tab:blue")
+ ax.axhline(26400 * local_registry.feet, color="tab:red")
+ ax.axvline(120 * local_registry.minutes, color="tab:green")
+
+ return fig
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 83448ce..f0f95bc 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -806,7 +806,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.around(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
)
helpers.assert_quantity_equal(
- np.round_(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
+ np.round(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
)
def test_trace(self):
@@ -1050,7 +1050,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.isclose(self.q, q2), np.array([[False, True], [True, False]])
)
self.assertNDArrayEqual(
- np.isclose(self.q, q2, atol=1e-5, rtol=1e-7),
+ np.isclose(self.q, q2, atol=1e-5 * self.ureg.mm, rtol=1e-7),
np.array([[False, True], [True, False]]),
)
@@ -1222,6 +1222,24 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.array([[1, 0, 2], [3, 0, 4]]) * self.ureg.m,
)
+ @helpers.requires_array_function_protocol()
+ def test_delete(self):
+ q = self.Q_(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), "m")
+ helpers.assert_quantity_equal(
+ np.delete(q, 1, axis=0),
+ np.array([[1, 2, 3, 4], [9, 10, 11, 12]]) * self.ureg.m,
+ )
+
+ helpers.assert_quantity_equal(
+ np.delete(q, np.s_[::2], 1),
+ np.array([[2, 4], [6, 8], [10, 12]]) * self.ureg.m,
+ )
+
+ helpers.assert_quantity_equal(
+ np.delete(q, [1, 3, 5], None),
+ np.array([1, 3, 5, 7, 8, 9, 10, 11, 12]) * self.ureg.m,
+ )
+
def test_ndarray_downcast(self):
with pytest.warns(UnitStrippedWarning):
np.asarray(self.q)
@@ -1355,6 +1373,16 @@ class TestNumpyUnclassified(TestNumpyMethods):
assert not np.allclose(
[1e10, 1e-8] * self.ureg.m, [1.00001e10, 1e-9] * self.ureg.mm
)
+ assert np.allclose(
+ [1e10, 1e-8] * self.ureg.m,
+ [1.00001e10, 1e-9] * self.ureg.m,
+ atol=1e-8 * self.ureg.m,
+ )
+
+ with pytest.raises(DimensionalityError):
+ assert np.allclose(
+ [1e10, 1e-8] * self.ureg.m, [1.00001e10, 1e-9] * self.ureg.m, atol=1e-8
+ )
@helpers.requires_array_function_protocol()
def test_intersect1d(self):
diff --git a/pint/testsuite/test_pint_eval.py b/pint/testsuite/test_pint_eval.py
index bed8105..b5b94f0 100644
--- a/pint/testsuite/test_pint_eval.py
+++ b/pint/testsuite/test_pint_eval.py
@@ -2,10 +2,13 @@ import pytest
from pint.compat import tokenizer
from pint.pint_eval import build_eval_tree
+from pint.util import string_preprocessor
class TestPintEval:
- def _test_one(self, input_text, parsed):
+ def _test_one(self, input_text, parsed, preprocess=False):
+ if preprocess:
+ input_text = string_preprocessor(input_text)
assert build_eval_tree(tokenizer(input_text)).to_string() == parsed
@pytest.mark.parametrize(
@@ -13,6 +16,7 @@ class TestPintEval:
(
("3", "3"),
("1 + 2", "(1 + 2)"),
+ ("1 - 2", "(1 - 2)"),
("2 * 3 + 4", "((2 * 3) + 4)"), # order of operations
("2 * (3 + 4)", "(2 * (3 + 4))"), # parentheses
(
@@ -71,4 +75,70 @@ class TestPintEval:
),
)
def test_build_eval_tree(self, input_text, parsed):
- self._test_one(input_text, parsed)
+ self._test_one(input_text, parsed, preprocess=False)
+
+ @pytest.mark.parametrize(
+ ("input_text", "parsed"),
+ (
+ ("3", "3"),
+ ("1 + 2", "(1 + 2)"),
+ ("1 - 2", "(1 - 2)"),
+ ("2 * 3 + 4", "((2 * 3) + 4)"), # order of operations
+ ("2 * (3 + 4)", "(2 * (3 + 4))"), # parentheses
+ (
+ "1 + 2 * 3 ** (4 + 3 / 5)",
+ "(1 + (2 * (3 ** (4 + (3 / 5)))))",
+ ), # more order of operations
+ (
+ "1 * ((3 + 4) * 5)",
+ "(1 * ((3 + 4) * 5))",
+ ), # nested parentheses at beginning
+ ("1 * (5 * (3 + 4))", "(1 * (5 * (3 + 4)))"), # nested parentheses at end
+ (
+ "1 * (5 * (3 + 4) / 6)",
+ "(1 * ((5 * (3 + 4)) / 6))",
+ ), # nested parentheses in middle
+ ("-1", "(- 1)"), # unary
+ ("3 * -1", "(3 * (- 1))"), # unary
+ ("3 * --1", "(3 * (- (- 1)))"), # double unary
+ ("3 * -(2 + 4)", "(3 * (- (2 + 4)))"), # parenthetical unary
+ ("3 * -((2 + 4))", "(3 * (- (2 + 4)))"), # parenthetical unary
+ # implicit op
+ ("3 4", "(3 * 4)"),
+ # implicit op, then parentheses
+ ("3 (2 + 4)", "(3 * (2 + 4))"),
+ # parentheses, then implicit
+ ("(3 ** 4 ) 5", "((3 ** 4) * 5)"),
+ # implicit op, then exponentiation
+ ("3 4 ** 5", "(3 * (4 ** 5))"),
+ # implicit op, then addition
+ ("3 4 + 5", "((3 * 4) + 5)"),
+ # power followed by implicit
+ ("3 ** 4 5", "((3 ** 4) * 5)"),
+ # implicit with parentheses
+ ("3 (4 ** 5)", "(3 * (4 ** 5))"),
+ # exponent with e
+ ("3e-1", "3e-1"),
+ # multiple units with exponents
+ ("kg ** 1 * s ** 2", "((kg ** 1) * (s ** 2))"),
+ # multiple units with neg exponents
+ ("kg ** -1 * s ** -2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # multiple units with neg exponents
+ ("kg^-1 * s^-2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # multiple units with neg exponents, implicit op
+ ("kg^-1 s^-2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # nested power
+ ("2 ^ 3 ^ 2", "(2 ** (3 ** 2))"),
+ # nested power
+ ("gram * second / meter ** 2", "((gram * second) / (meter ** 2))"),
+ # nested power
+ ("gram / meter ** 2 / second", "((gram / (meter ** 2)) / second)"),
+ # units should behave like numbers, so we don't need a bunch of extra tests for them
+ # implicit op, then addition
+ ("3 kg + 5", "((3 * kg) + 5)"),
+ ("(5 % 2) m", "((5 % 2) * m)"), # mod operator
+ ("(5 // 2) m", "((5 // 2) * m)"), # floordiv operator
+ ),
+ )
+ def test_preprocessed_eval_tree(self, input_text, parsed):
+ self._test_one(input_text, parsed, preprocess=True)
diff --git a/pint/util.py b/pint/util.py
index 3d00175..420914f 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -566,7 +566,7 @@ class ParserHelper(UnitsContainer):
if non_int_type is float:
return cls(1, [(input_word, 1)], non_int_type=non_int_type)
else:
- ONE = non_int_type("1.0")
+ ONE = non_int_type("1")
return cls(ONE, [(input_word, ONE)], non_int_type=non_int_type)
@classmethod
@@ -765,7 +765,7 @@ _subs_re_list = [
r"\b([0-9]+\.?[0-9]*)(?=[e|E][a-zA-Z]|[a-df-zA-DF-Z])",
r"\1*",
), # Handle numberLetter for multiplication
- (r"([\w\.\-])\s+(?=\w)", r"\1*"), # Handle space for multiplication
+ (r"([\w\.\)])\s+(?=[\w\(])", r"\1*"), # Handle space for multiplication
]
#: Compiles the regex and replace {} by a regex that matches an identifier.