From badbf70324274bdb4299d8c64d3d83a26be2d4c0 Mon Sep 17 00:00:00 2001 From: Ohad Ravid Date: Wed, 24 Mar 2021 15:45:16 +0200 Subject: ENH: Improve performance of `np.save` for small arrays (gh-18657) * ENH: Remove call to `_filter_header` from `_write_array_header` Improve performance of `np.save` by removing the call when writing the header, as it is known to be done in Python 3. * ENH: Only call `_filter_header` from `_read_array_header` for old vers Improve performance of `np.load` for arrays with version >= (3,0) by removing the call, as it is known to be done in Python 3. * ENH: Use a set of keys when checking `read_array` Improve performance of `np.load`. * DOC: Improve performance of `np.{save,load}` for small arrays --- numpy/lib/format.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'numpy/lib/format.py') diff --git a/numpy/lib/format.py b/numpy/lib/format.py index 904c32cc7..ead6a0420 100644 --- a/numpy/lib/format.py +++ b/numpy/lib/format.py @@ -173,6 +173,7 @@ from numpy.compat import ( __all__ = [] +EXPECTED_KEYS = {'descr', 'fortran_order', 'shape'} MAGIC_PREFIX = b'\x93NUMPY' MAGIC_LEN = len(MAGIC_PREFIX) + 2 ARRAY_ALIGN = 64 # plausible values are powers of 2 between 16 and 4096 @@ -432,7 +433,6 @@ def _write_array_header(fp, d, version=None): header.append("'%s': %s, " % (key, repr(value))) header.append("}") header = "".join(header) - header = _filter_header(header) if version is None: header = _wrap_header_guess_version(header) else: @@ -590,7 +590,10 @@ def _read_array_header(fp, version): # "shape" : tuple of int # "fortran_order" : bool # "descr" : dtype.descr - header = _filter_header(header) + # Versions (2, 0) and (1, 0) could have been created by a Python 2 + # implementation before header filtering was implemented. + if version <= (2, 0): + header = _filter_header(header) try: d = safe_eval(header) except SyntaxError as e: @@ -599,14 +602,15 @@ def _read_array_header(fp, version): if not isinstance(d, dict): msg = "Header is not a dictionary: {!r}" raise ValueError(msg.format(d)) - keys = sorted(d.keys()) - if keys != ['descr', 'fortran_order', 'shape']: + + if EXPECTED_KEYS != d.keys(): + keys = sorted(d.keys()) msg = "Header does not contain the correct keys: {!r}" - raise ValueError(msg.format(keys)) + raise ValueError(msg.format(d.keys())) # Sanity-check the values. if (not isinstance(d['shape'], tuple) or - not numpy.all([isinstance(x, int) for x in d['shape']])): + not all(isinstance(x, int) for x in d['shape'])): msg = "shape is not valid: {!r}" raise ValueError(msg.format(d['shape'])) if not isinstance(d['fortran_order'], bool): -- cgit v1.2.1