summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.pre-commit-config.yaml11
-rw-r--r--CHANGES10
-rw-r--r--README.rst4
-rw-r--r--docs/user/numpy.ipynb85
-rw-r--r--docs/user/plotting.rst25
-rw-r--r--pint/facets/dask/__init__.py14
-rw-r--r--pint/facets/numpy/numpy_func.py20
-rw-r--r--pint/facets/plain/registry.py4
-rw-r--r--pint/matplotlib.py3
-rw-r--r--pint/testsuite/baseline/test_basic_plot.pngbin17483 -> 17415 bytes
-rw-r--r--pint/testsuite/baseline/test_plot_with_non_default_format.pngbin0 -> 16617 bytes
-rw-r--r--pint/testsuite/baseline/test_plot_with_set_units.pngbin18145 -> 18176 bytes
-rw-r--r--pint/testsuite/test_dask.py2
-rw-r--r--pint/testsuite/test_matplotlib.py18
-rw-r--r--pint/testsuite/test_numpy.py32
-rw-r--r--pint/testsuite/test_pint_eval.py74
-rw-r--r--pint/util.py2
17 files changed, 227 insertions, 77 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 83587c6..74e4522 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,14 +7,19 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- repo: https://github.com/psf/black
- rev: 22.10.0
+ rev: 22.12.0
hooks:
- - id: black
+ - id: black-jupyter
- repo: https://github.com/pycqa/isort
- rev: 5.10.1
+ rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
+- repo: https://github.com/kynan/nbstripout
+ rev: 0.6.1
+ hooks:
+ - id: nbstripout
+ args: [--extra-keys=metadata.kernelspec metadata.language_info.version]
diff --git a/CHANGES b/CHANGES
index 4f79f8f..e1084b4 100644
--- a/CHANGES
+++ b/CHANGES
@@ -4,21 +4,31 @@ Pint Changelog
0.21 (unreleased)
-----------------
+- Exposed matplotlib unit formatter (PR #1703)
- Fix error when when re-registering a formatter.
(PR #1629)
- Add new SI prefixes: ronna-, ronto-, quetta-, quecto-.
(PR #1652)
+- Fix unit check with `atol` using `np.allclose` & `np.isclose`.
+ (Issue #1658)
- Implementation for numpy.positive added for Quantity.
(PR #1663)
- Changed frequency to angular frequency in the docs.
(PR #1668)
- Avoid addition of spurious trailing zeros when converting units and non-int-type is
Decimal (PR #1625).
+- Implementation for numpy.delete added for Quantity.
+ (PR #1669)
+- Fixed Quantity type returned from `__dask_postcompute__`.
+ (PR #1722)
+
### Breaking Changes
- Support percent and ppm units. Support the `%` symbol.
(Issue #1277)
+- Fix error when parsing subtraction operator followed by white space.
+ (PR #1701)
0.20.1 (2022-10-27)
diff --git a/README.rst b/README.rst
index 86c8f77..32879d9 100644
--- a/README.rst
+++ b/README.rst
@@ -153,7 +153,7 @@ see CHANGES_
.. _`NumPy`: http://www.numpy.org/
.. _`PEP 3101`: https://www.python.org/dev/peps/pep-3101/
.. _`Babel`: http://babel.pocoo.org/
-.. _`Pandas Extension Types`: https://pandas.pydata.org/pandas-docs/stable/extending.html#extension-types
-.. _`pint-pandas Jupyter notebook`: https://github.com/hgrecco/pint-pandas/blob/master/notebooks/pandas_support.ipynb
+.. _`Pandas Extension Types`: https://pandas.pydata.org/pandas-docs/stable/development/extending.html#extension-types
+.. _`pint-pandas Jupyter notebook`: https://github.com/hgrecco/pint-pandas/blob/master/notebooks/pint-pandas.ipynb
.. _`AUTHORS`: https://github.com/hgrecco/pint/blob/master/AUTHORS
.. _`CHANGES`: https://github.com/hgrecco/pint/blob/master/CHANGES
diff --git a/docs/user/numpy.ipynb b/docs/user/numpy.ipynb
index 2586626..5491001 100644
--- a/docs/user/numpy.ipynb
+++ b/docs/user/numpy.ipynb
@@ -37,11 +37,13 @@
"\n",
"# Import Pint\n",
"import pint\n",
+ "\n",
"ureg = pint.UnitRegistry()\n",
"Q_ = ureg.Quantity\n",
"\n",
"# Silence NEP 18 warning\n",
"import warnings\n",
+ "\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" Q_([])"
@@ -68,7 +70,7 @@
},
"outputs": [],
"source": [
- "legs1 = Q_(np.asarray([3., 4.]), 'meter')\n",
+ "legs1 = Q_(np.asarray([3.0, 4.0]), \"meter\")\n",
"print(legs1)"
]
},
@@ -82,7 +84,7 @@
},
"outputs": [],
"source": [
- "legs1 = [3., 4.] * ureg.meter\n",
+ "legs1 = [3.0, 4.0] * ureg.meter\n",
"print(legs1)"
]
},
@@ -107,7 +109,7 @@
},
"outputs": [],
"source": [
- "print(legs1.to('kilometer'))"
+ "print(legs1.to(\"kilometer\"))"
]
},
{
@@ -134,7 +136,7 @@
"outputs": [],
"source": [
"try:\n",
- " legs1.to('joule')\n",
+ " legs1.to(\"joule\")\n",
"except pint.DimensionalityError as exc:\n",
" print(exc)"
]
@@ -160,7 +162,7 @@
},
"outputs": [],
"source": [
- "legs2 = [400., 300.] * ureg.centimeter\n",
+ "legs2 = [400.0, 300.0] * ureg.centimeter\n",
"print(legs2)"
]
},
@@ -214,7 +216,7 @@
},
"outputs": [],
"source": [
- "angles = np.arccos(legs2/hyps)\n",
+ "angles = np.arccos(legs2 / hyps)\n",
"print(angles)"
]
},
@@ -239,7 +241,7 @@
},
"outputs": [],
"source": [
- "print(angles.to('degree'))"
+ "print(angles.to(\"degree\"))"
]
},
{
@@ -302,6 +304,7 @@
"outputs": [],
"source": [
"from pint.facets.numpy.numpy_func import HANDLED_FUNCTIONS\n",
+ "\n",
"print(sorted(list(HANDLED_FUNCTIONS)))"
]
},
@@ -374,27 +377,27 @@
"source": [
"from graphviz import Digraph\n",
"\n",
- "g = Digraph(graph_attr={'size': '8,5'}, node_attr={'fontname': 'courier'})\n",
- "g.edge('Dask array', 'NumPy ndarray')\n",
- "g.edge('Dask array', 'CuPy ndarray')\n",
- "g.edge('Dask array', 'Sparse COO')\n",
- "g.edge('Dask array', 'NumPy masked array', style='dashed')\n",
- "g.edge('CuPy ndarray', 'NumPy ndarray')\n",
- "g.edge('Sparse COO', 'NumPy ndarray')\n",
- "g.edge('NumPy masked array', 'NumPy ndarray')\n",
- "g.edge('Jax array', 'NumPy ndarray')\n",
- "g.edge('Pint Quantity', 'Dask array', style='dashed')\n",
- "g.edge('Pint Quantity', 'NumPy ndarray')\n",
- "g.edge('Pint Quantity', 'CuPy ndarray', style='dashed')\n",
- "g.edge('Pint Quantity', 'Sparse COO')\n",
- "g.edge('Pint Quantity', 'NumPy masked array', style='dashed')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'Dask array')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'CuPy ndarray', style='dashed')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'Sparse COO')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'NumPy ndarray')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'NumPy masked array', style='dashed')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'Pint Quantity')\n",
- "g.edge('xarray Dataset/DataArray/Variable', 'Jax array', style='dashed')\n",
+ "g = Digraph(graph_attr={\"size\": \"8,5\"}, node_attr={\"fontname\": \"courier\"})\n",
+ "g.edge(\"Dask array\", \"NumPy ndarray\")\n",
+ "g.edge(\"Dask array\", \"CuPy ndarray\")\n",
+ "g.edge(\"Dask array\", \"Sparse COO\")\n",
+ "g.edge(\"Dask array\", \"NumPy masked array\", style=\"dashed\")\n",
+ "g.edge(\"CuPy ndarray\", \"NumPy ndarray\")\n",
+ "g.edge(\"Sparse COO\", \"NumPy ndarray\")\n",
+ "g.edge(\"NumPy masked array\", \"NumPy ndarray\")\n",
+ "g.edge(\"Jax array\", \"NumPy ndarray\")\n",
+ "g.edge(\"Pint Quantity\", \"Dask array\", style=\"dashed\")\n",
+ "g.edge(\"Pint Quantity\", \"NumPy ndarray\")\n",
+ "g.edge(\"Pint Quantity\", \"CuPy ndarray\", style=\"dashed\")\n",
+ "g.edge(\"Pint Quantity\", \"Sparse COO\")\n",
+ "g.edge(\"Pint Quantity\", \"NumPy masked array\", style=\"dashed\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"Dask array\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"CuPy ndarray\", style=\"dashed\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"Sparse COO\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"NumPy ndarray\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"NumPy masked array\", style=\"dashed\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"Pint Quantity\")\n",
+ "g.edge(\"xarray Dataset/DataArray/Variable\", \"Jax array\", style=\"dashed\")\n",
"g"
]
},
@@ -424,10 +427,10 @@
"import xarray as xr\n",
"\n",
"# Load tutorial data\n",
- "air = xr.tutorial.load_dataset('air_temperature')['air'][0]\n",
+ "air = xr.tutorial.load_dataset(\"air_temperature\")[\"air\"][0]\n",
"\n",
"# Convert to Quantity\n",
- "air.data = Q_(air.data, air.attrs.pop('units', ''))\n",
+ "air.data = Q_(air.data, air.attrs.pop(\"units\", \"\"))\n",
"\n",
"print(air)\n",
"print()\n",
@@ -494,7 +497,7 @@
"m = np.ma.masked_array([2, 3, 5, 7], mask=[False, True, False, True])\n",
"\n",
"# Must create using Quantity class\n",
- "print(repr(ureg.Quantity(m, 'm')))\n",
+ "print(repr(ureg.Quantity(m, \"m\")))\n",
"print()\n",
"\n",
"# DO NOT create using multiplication until\n",
@@ -568,14 +571,14 @@
"x[x < 0.95] = 0\n",
"\n",
"data = xr.DataArray(\n",
- " Q_(x.map_blocks(COO), 'm'),\n",
- " dims=('z', 'y', 'x'),\n",
+ " Q_(x.map_blocks(COO), \"m\"),\n",
+ " dims=(\"z\", \"y\", \"x\"),\n",
" coords={\n",
- " 'z': np.arange(100),\n",
- " 'y': np.arange(100) - 50,\n",
- " 'x': np.arange(100) * 1.5 - 20\n",
+ " \"z\": np.arange(100),\n",
+ " \"y\": np.arange(100) - 50,\n",
+ " \"x\": np.arange(100) * 1.5 - 20,\n",
" },\n",
- " name='test'\n",
+ " name=\"test\",\n",
")\n",
"\n",
"print(data)\n",
@@ -627,11 +630,6 @@
}
],
"metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
"language_info": {
"codemirror_mode": {
"name": "ipython",
@@ -641,8 +639,7 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.2"
+ "pygments_lexer": "ipython3"
}
},
"nbformat": 4,
diff --git a/docs/user/plotting.rst b/docs/user/plotting.rst
index a008d45..3c3fc39 100644
--- a/docs/user/plotting.rst
+++ b/docs/user/plotting.rst
@@ -70,6 +70,31 @@ This also allows controlling the actual plotting units for the x and y axes:
ax.axhline(26400 * ureg.feet, color='tab:red')
ax.axvline(120 * ureg.minutes, color='tab:green')
+Users have the possibility to change the format of the units on the plot:
+
+.. plot::
+ :include-source: true
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pint
+
+ ureg = pint.UnitRegistry()
+ ureg.setup_matplotlib(True)
+
+ ureg.mpl_formatter = "{:~P}"
+
+ y = np.linspace(0, 30) * ureg.miles
+ x = np.linspace(0, 5) * ureg.hours
+
+ fig, ax = plt.subplots()
+ ax.yaxis.set_units(ureg.inches)
+ ax.xaxis.set_units(ureg.seconds)
+
+ ax.plot(x, y, 'tab:blue')
+ ax.axhline(26400 * ureg.feet, color='tab:red')
+ ax.axvline(120 * ureg.minutes, color='tab:green')
+
For more information, visit the Matplotlib_ home page.
.. _Matplotlib: https://matplotlib.org
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index f99e8a2..5276d3c 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -46,10 +46,7 @@ class DaskQuantity:
def __dask_tokenize__(self):
from dask.base import tokenize
- from pint import UnitRegistry
-
- # TODO: Check if this is the right class as first argument
- return (UnitRegistry.Quantity, tokenize(self._magnitude), self.units)
+ return (type(self), tokenize(self._magnitude), self.units)
@property
def __dask_optimize__(self):
@@ -67,14 +64,9 @@ class DaskQuantity:
func, args = self._magnitude.__dask_postpersist__()
return self._dask_finalize, (func, args, self.units)
- @staticmethod
- def _dask_finalize(results, func, args, units):
+ def _dask_finalize(self, results, func, args, units):
values = func(results, *args)
-
- from pint import Quantity
-
- # TODO: Check if this is the right class as first argument
- return Quantity(values, units)
+ return type(self)(values, units)
@check_dask_array
def compute(self, **kwargs):
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 7bce41e..138ed24 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -527,22 +527,16 @@ def _meshgrid(*xi, **kwargs):
@implements("full_like", "function")
-def _full_like(a, fill_value, dtype=None, order="K", subok=True, shape=None):
+def _full_like(a, fill_value, **kwargs):
# Make full_like by multiplying with array from ones_like in a
# non-multiplicative-unit-safe way
if hasattr(fill_value, "_REGISTRY"):
return fill_value._REGISTRY.Quantity(
- (
- np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
- * fill_value.m
- ),
+ np.ones_like(a, **kwargs) * fill_value.m,
fill_value.units,
)
else:
- return (
- np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
- * fill_value
- )
+ return np.ones_like(a, **kwargs) * fill_value
@implements("interp", "function")
@@ -796,6 +790,7 @@ for func_str, unit_arguments, wrap_output in [
("ptp", "a", True),
("ravel", "a", True),
("round_", "a", True),
+ ("round", "a", True),
("sort", "a", True),
("median", "a", True),
("nanmedian", "a", True),
@@ -816,8 +811,10 @@ for func_str, unit_arguments, wrap_output in [
("broadcast_to", ["array"], True),
("amax", ["a", "initial"], True),
("amin", ["a", "initial"], True),
+ ("max", ["a", "initial"], True),
+ ("min", ["a", "initial"], True),
("searchsorted", ["a", "v"], False),
- ("isclose", ["a", "b"], False),
+ ("isclose", ["a", "b", "atol"], False),
("nan_to_num", ["x", "nan", "posinf", "neginf"], True),
("clip", ["a", "a_min", "a_max"], True),
("append", ["arr", "values"], True),
@@ -827,9 +824,10 @@ for func_str, unit_arguments, wrap_output in [
("lib.stride_tricks.sliding_window_view", "x", True),
("rot90", "m", True),
("insert", ["arr", "values"], True),
+ ("delete", ["arr"], True),
("resize", "a", True),
("reshape", "a", True),
- ("allclose", ["a", "b"], False),
+ ("allclose", ["a", "b", "atol"], False),
("intersect1d", ["ar1", "ar2"], True),
]:
implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output)
diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py
index ffa6fb4..eed73e1 100644
--- a/pint/facets/plain/registry.py
+++ b/pint/facets/plain/registry.py
@@ -204,6 +204,7 @@ class PlainRegistry(metaclass=RegistryMeta):
case_sensitive: bool = True,
cache_folder: Union[str, pathlib.Path, None] = None,
separate_format_defaults: Optional[bool] = None,
+ mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
self._adders = dict()
@@ -244,6 +245,9 @@ class PlainRegistry(metaclass=RegistryMeta):
#: Default locale identifier string, used when calling format_babel without explicit locale.
self.set_fmt_locale(fmt_locale)
+ #: sets the formatter used when plotting with matplotlib
+ self.mpl_formatter = mpl_formatter
+
#: Numerical type used for non integer values.
self._non_int_type = non_int_type
diff --git a/pint/matplotlib.py b/pint/matplotlib.py
index 3785c7d..ea88c70 100644
--- a/pint/matplotlib.py
+++ b/pint/matplotlib.py
@@ -21,7 +21,8 @@ class PintAxisInfo(matplotlib.units.AxisInfo):
def __init__(self, units):
"""Set the default label to the pretty-print of the unit."""
- super().__init__(label="{:P}".format(units))
+ formatter = units._REGISTRY.mpl_formatter
+ super().__init__(label=formatter.format(units))
class PintConverter(matplotlib.units.ConversionInterface):
diff --git a/pint/testsuite/baseline/test_basic_plot.png b/pint/testsuite/baseline/test_basic_plot.png
index 63be609..b0c4d18 100644
--- a/pint/testsuite/baseline/test_basic_plot.png
+++ b/pint/testsuite/baseline/test_basic_plot.png
Binary files differ
diff --git a/pint/testsuite/baseline/test_plot_with_non_default_format.png b/pint/testsuite/baseline/test_plot_with_non_default_format.png
new file mode 100644
index 0000000..1cb5b18
--- /dev/null
+++ b/pint/testsuite/baseline/test_plot_with_non_default_format.png
Binary files differ
diff --git a/pint/testsuite/baseline/test_plot_with_set_units.png b/pint/testsuite/baseline/test_plot_with_set_units.png
index 5fd3ce0..a59924c 100644
--- a/pint/testsuite/baseline/test_plot_with_set_units.png
+++ b/pint/testsuite/baseline/test_plot_with_set_units.png
Binary files differ
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index 69c80fe..f4dee6a 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -149,6 +149,8 @@ def test_compute_persist_equivalent(local_registry, dask_array, numpy_array):
assert np.all(res_compute == res_persist)
assert res_compute.units == res_persist.units == units_
+ assert type(res_compute) == local_registry.Quantity
+ assert type(res_persist) == local_registry.Quantity
@pytest.mark.parametrize("method", ["compute", "persist", "visualize"])
diff --git a/pint/testsuite/test_matplotlib.py b/pint/testsuite/test_matplotlib.py
index 25f3172..0735721 100644
--- a/pint/testsuite/test_matplotlib.py
+++ b/pint/testsuite/test_matplotlib.py
@@ -46,3 +46,21 @@ def test_plot_with_set_units(local_registry):
ax.axvline(120 * local_registry.minutes, color="tab:green")
return fig
+
+
+@pytest.mark.mpl_image_compare(tolerance=0, remove_text=True)
+def test_plot_with_non_default_format(local_registry):
+ local_registry.mpl_formatter = "{:~P}"
+
+ y = np.linspace(0, 30) * local_registry.miles
+ x = np.linspace(0, 5) * local_registry.hours
+
+ fig, ax = plt.subplots()
+ ax.yaxis.set_units(local_registry.inches)
+ ax.xaxis.set_units(local_registry.seconds)
+
+ ax.plot(x, y, "tab:blue")
+ ax.axhline(26400 * local_registry.feet, color="tab:red")
+ ax.axvline(120 * local_registry.minutes, color="tab:green")
+
+ return fig
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 83448ce..f0f95bc 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -806,7 +806,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.around(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
)
helpers.assert_quantity_equal(
- np.round_(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
+ np.round(1.0275 * self.ureg.m, decimals=2), 1.03 * self.ureg.m
)
def test_trace(self):
@@ -1050,7 +1050,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.isclose(self.q, q2), np.array([[False, True], [True, False]])
)
self.assertNDArrayEqual(
- np.isclose(self.q, q2, atol=1e-5, rtol=1e-7),
+ np.isclose(self.q, q2, atol=1e-5 * self.ureg.mm, rtol=1e-7),
np.array([[False, True], [True, False]]),
)
@@ -1222,6 +1222,24 @@ class TestNumpyUnclassified(TestNumpyMethods):
np.array([[1, 0, 2], [3, 0, 4]]) * self.ureg.m,
)
+ @helpers.requires_array_function_protocol()
+ def test_delete(self):
+ q = self.Q_(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]), "m")
+ helpers.assert_quantity_equal(
+ np.delete(q, 1, axis=0),
+ np.array([[1, 2, 3, 4], [9, 10, 11, 12]]) * self.ureg.m,
+ )
+
+ helpers.assert_quantity_equal(
+ np.delete(q, np.s_[::2], 1),
+ np.array([[2, 4], [6, 8], [10, 12]]) * self.ureg.m,
+ )
+
+ helpers.assert_quantity_equal(
+ np.delete(q, [1, 3, 5], None),
+ np.array([1, 3, 5, 7, 8, 9, 10, 11, 12]) * self.ureg.m,
+ )
+
def test_ndarray_downcast(self):
with pytest.warns(UnitStrippedWarning):
np.asarray(self.q)
@@ -1355,6 +1373,16 @@ class TestNumpyUnclassified(TestNumpyMethods):
assert not np.allclose(
[1e10, 1e-8] * self.ureg.m, [1.00001e10, 1e-9] * self.ureg.mm
)
+ assert np.allclose(
+ [1e10, 1e-8] * self.ureg.m,
+ [1.00001e10, 1e-9] * self.ureg.m,
+ atol=1e-8 * self.ureg.m,
+ )
+
+ with pytest.raises(DimensionalityError):
+ assert np.allclose(
+ [1e10, 1e-8] * self.ureg.m, [1.00001e10, 1e-9] * self.ureg.m, atol=1e-8
+ )
@helpers.requires_array_function_protocol()
def test_intersect1d(self):
diff --git a/pint/testsuite/test_pint_eval.py b/pint/testsuite/test_pint_eval.py
index bed8105..b5b94f0 100644
--- a/pint/testsuite/test_pint_eval.py
+++ b/pint/testsuite/test_pint_eval.py
@@ -2,10 +2,13 @@ import pytest
from pint.compat import tokenizer
from pint.pint_eval import build_eval_tree
+from pint.util import string_preprocessor
class TestPintEval:
- def _test_one(self, input_text, parsed):
+ def _test_one(self, input_text, parsed, preprocess=False):
+ if preprocess:
+ input_text = string_preprocessor(input_text)
assert build_eval_tree(tokenizer(input_text)).to_string() == parsed
@pytest.mark.parametrize(
@@ -13,6 +16,7 @@ class TestPintEval:
(
("3", "3"),
("1 + 2", "(1 + 2)"),
+ ("1 - 2", "(1 - 2)"),
("2 * 3 + 4", "((2 * 3) + 4)"), # order of operations
("2 * (3 + 4)", "(2 * (3 + 4))"), # parentheses
(
@@ -71,4 +75,70 @@ class TestPintEval:
),
)
def test_build_eval_tree(self, input_text, parsed):
- self._test_one(input_text, parsed)
+ self._test_one(input_text, parsed, preprocess=False)
+
+ @pytest.mark.parametrize(
+ ("input_text", "parsed"),
+ (
+ ("3", "3"),
+ ("1 + 2", "(1 + 2)"),
+ ("1 - 2", "(1 - 2)"),
+ ("2 * 3 + 4", "((2 * 3) + 4)"), # order of operations
+ ("2 * (3 + 4)", "(2 * (3 + 4))"), # parentheses
+ (
+ "1 + 2 * 3 ** (4 + 3 / 5)",
+ "(1 + (2 * (3 ** (4 + (3 / 5)))))",
+ ), # more order of operations
+ (
+ "1 * ((3 + 4) * 5)",
+ "(1 * ((3 + 4) * 5))",
+ ), # nested parentheses at beginning
+ ("1 * (5 * (3 + 4))", "(1 * (5 * (3 + 4)))"), # nested parentheses at end
+ (
+ "1 * (5 * (3 + 4) / 6)",
+ "(1 * ((5 * (3 + 4)) / 6))",
+ ), # nested parentheses in middle
+ ("-1", "(- 1)"), # unary
+ ("3 * -1", "(3 * (- 1))"), # unary
+ ("3 * --1", "(3 * (- (- 1)))"), # double unary
+ ("3 * -(2 + 4)", "(3 * (- (2 + 4)))"), # parenthetical unary
+ ("3 * -((2 + 4))", "(3 * (- (2 + 4)))"), # parenthetical unary
+ # implicit op
+ ("3 4", "(3 * 4)"),
+ # implicit op, then parentheses
+ ("3 (2 + 4)", "(3 * (2 + 4))"),
+ # parentheses, then implicit
+ ("(3 ** 4 ) 5", "((3 ** 4) * 5)"),
+ # implicit op, then exponentiation
+ ("3 4 ** 5", "(3 * (4 ** 5))"),
+ # implicit op, then addition
+ ("3 4 + 5", "((3 * 4) + 5)"),
+ # power followed by implicit
+ ("3 ** 4 5", "((3 ** 4) * 5)"),
+ # implicit with parentheses
+ ("3 (4 ** 5)", "(3 * (4 ** 5))"),
+ # exponent with e
+ ("3e-1", "3e-1"),
+ # multiple units with exponents
+ ("kg ** 1 * s ** 2", "((kg ** 1) * (s ** 2))"),
+ # multiple units with neg exponents
+ ("kg ** -1 * s ** -2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # multiple units with neg exponents
+ ("kg^-1 * s^-2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # multiple units with neg exponents, implicit op
+ ("kg^-1 s^-2", "((kg ** (- 1)) * (s ** (- 2)))"),
+ # nested power
+ ("2 ^ 3 ^ 2", "(2 ** (3 ** 2))"),
+ # nested power
+ ("gram * second / meter ** 2", "((gram * second) / (meter ** 2))"),
+ # nested power
+ ("gram / meter ** 2 / second", "((gram / (meter ** 2)) / second)"),
+ # units should behave like numbers, so we don't need a bunch of extra tests for them
+ # implicit op, then addition
+ ("3 kg + 5", "((3 * kg) + 5)"),
+ ("(5 % 2) m", "((5 % 2) * m)"), # mod operator
+ ("(5 // 2) m", "((5 // 2) * m)"), # floordiv operator
+ ),
+ )
+ def test_preprocessed_eval_tree(self, input_text, parsed):
+ self._test_one(input_text, parsed, preprocess=True)
diff --git a/pint/util.py b/pint/util.py
index 78ea45d..420914f 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -765,7 +765,7 @@ _subs_re_list = [
r"\b([0-9]+\.?[0-9]*)(?=[e|E][a-zA-Z]|[a-df-zA-DF-Z])",
r"\1*",
), # Handle numberLetter for multiplication
- (r"([\w\.\-])\s+(?=\w)", r"\1*"), # Handle space for multiplication
+ (r"([\w\.\)])\s+(?=[\w\(])", r"\1*"), # Handle space for multiplication
]
#: Compiles the regex and replace {} by a regex that matches an identifier.