summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2021-08-16 08:21:30 -0600
committerGitHub <noreply@github.com>2021-08-16 08:21:30 -0600
commiteb31586efe52c5cdc8ab1b37eeb739c9e2a1ccc4 (patch)
tree5786994f172faa68e0997312b0c290dca9557bed
parent6cd4b015dfd512fb4bc9d2e06a0f0cef4503582b (diff)
parentd01a312de72bb4eef007bf661c22ea8e21e71378 (diff)
downloadnumpy-eb31586efe52c5cdc8ab1b37eeb739c9e2a1ccc4.tar.gz
Merge pull request #19667 from BvB93/npyio
ENH: Add annotations for `np.lib.npyio`
-rw-r--r--numpy/lib/npyio.pyi313
-rw-r--r--numpy/typing/tests/data/fail/modules.py1
-rw-r--r--numpy/typing/tests/data/fail/npyio.py31
-rw-r--r--numpy/typing/tests/data/reveal/npyio.py71
-rwxr-xr-xruntests.py8
5 files changed, 340 insertions, 84 deletions
diff --git a/numpy/lib/npyio.pyi b/numpy/lib/npyio.pyi
index f69edd564..264ceef14 100644
--- a/numpy/lib/npyio.pyi
+++ b/numpy/lib/npyio.pyi
@@ -1,98 +1,249 @@
-from typing import Mapping, List, Any
+import os
+import sys
+import zipfile
+import types
+from typing import (
+ Any,
+ Mapping,
+ TypeVar,
+ Generic,
+ List,
+ Type,
+ Iterator,
+ Union,
+ IO,
+ overload,
+ Sequence,
+ Callable,
+ Pattern,
+)
from numpy import (
DataSource as DataSource,
+ ndarray,
+ recarray,
+ dtype,
+ generic,
+ float64,
+ void,
)
+from numpy.ma.mrecords import MaskedRecords
+from numpy.typing import ArrayLike, DTypeLike, NDArray, _SupportsDType
+
from numpy.core.multiarray import (
packbits as packbits,
unpackbits as unpackbits,
)
+from typing_extensions import Protocol, Literal as L
+
+_T = TypeVar("_T")
+_T_contra = TypeVar("_T_contra", contravariant=True)
+_T_co = TypeVar("_T_co", covariant=True)
+_SCT = TypeVar("_SCT", bound=generic)
+
+_DTypeLike = Union[
+ Type[_SCT],
+ dtype[_SCT],
+ _SupportsDType[dtype[_SCT]],
+]
+
+class _SupportsGetItem(Protocol[_T_contra, _T_co]):
+ def __getitem__(self, key: _T_contra) -> _T_co: ...
+
__all__: List[str]
-class BagObj:
- def __init__(self, obj): ...
- def __getattribute__(self, key): ...
- def __dir__(self): ...
-
-def zipfile_factory(file, *args, **kwargs): ...
-
-class NpzFile(Mapping[Any, Any]):
- zip: Any
- fid: Any
- files: Any
- allow_pickle: Any
- pickle_kwargs: Any
- f: Any
- def __init__(self, fid, own_fid=..., allow_pickle=..., pickle_kwargs=...): ...
- def __enter__(self): ...
- def __exit__(self, exc_type, exc_value, traceback): ...
- def close(self): ...
- def __del__(self): ...
- def __iter__(self): ...
- def __len__(self): ...
- def __getitem__(self, key): ...
- def iteritems(self): ...
- def iterkeys(self): ...
-
-def load(file, mmap_mode=..., allow_pickle=..., fix_imports=..., encoding=...): ...
-def save(file, arr, allow_pickle=..., fix_imports=...): ...
-def savez(file, *args, **kwds): ...
-def savez_compressed(file, *args, **kwds): ...
+class BagObj(Generic[_T_co]):
+ def __init__(self, obj: _SupportsGetItem[str, _T_co]) -> None: ...
+ def __getattribute__(self, key: str) -> _T_co: ...
+ def __dir__(self) -> List[str]: ...
+
+class NpzFile(Mapping[str, NDArray[Any]]):
+ zip: zipfile.ZipFile
+ fid: None | IO[str]
+ files: List[str]
+ allow_pickle: bool
+ pickle_kwargs: None | Mapping[str, Any]
+ # Represent `f` as a mutable property so we can access the type of `self`
+ @property
+ def f(self: _T) -> BagObj[_T]: ...
+ @f.setter
+ def f(self: _T, value: BagObj[_T]) -> None: ...
+ def __init__(
+ self,
+ fid: IO[str],
+ own_fid: bool = ...,
+ allow_pickle: bool = ...,
+ pickle_kwargs: None | Mapping[str, Any] = ...,
+ ) -> None: ...
+ def __enter__(self: _T) -> _T: ...
+ def __exit__(
+ self,
+ __exc_type: None | Type[BaseException],
+ __exc_value: None | BaseException,
+ __traceback: None | types.TracebackType,
+ ) -> None: ...
+ def close(self) -> None: ...
+ def __del__(self) -> None: ...
+ def __iter__(self) -> Iterator[str]: ...
+ def __len__(self) -> int: ...
+ def __getitem__(self, key: str) -> NDArray[Any]: ...
+
+# NOTE: Returns a `NpzFile` if file is a zip file;
+# returns an `ndarray`/`memmap` otherwise
+def load(
+ file: str | bytes | os.PathLike[Any] | IO[bytes],
+ mmap_mode: L[None, "r+", "r", "w+", "c"] = ...,
+ allow_pickle: bool = ...,
+ fix_imports: bool = ...,
+ encoding: L["ASCII", "latin1", "bytes"] = ...,
+) -> Any: ...
+
+def save(
+ file: str | os.PathLike[str] | IO[bytes],
+ arr: ArrayLike,
+ allow_pickle: bool = ...,
+ fix_imports: bool = ...,
+) -> None: ...
+
+def savez(
+ file: str | os.PathLike[str] | IO[bytes],
+ *args: ArrayLike,
+ **kwds: ArrayLike,
+) -> None: ...
+
+def savez_compressed(
+ file: str | os.PathLike[str] | IO[bytes],
+ *args: ArrayLike,
+ **kwds: ArrayLike,
+) -> None: ...
+
+@overload
+def loadtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: None = ...,
+ comments: str | Sequence[str] = ...,
+ delimiter: None | str = ...,
+ converters: None | Mapping[int | str, Callable[[str], Any]] = ...,
+ skiprows: int = ...,
+ usecols: int | Sequence[int] = ...,
+ unpack: bool = ...,
+ ndmin: L[0, 1, 2] = ...,
+ encoding: None | str = ...,
+ max_rows: None | int = ...,
+ *,
+ like: None | ArrayLike = ...
+) -> NDArray[float64]: ...
+@overload
+def loadtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: _DTypeLike[_SCT],
+ comments: str | Sequence[str] = ...,
+ delimiter: None | str = ...,
+ converters: None | Mapping[int | str, Callable[[str], Any]] = ...,
+ skiprows: int = ...,
+ usecols: int | Sequence[int] = ...,
+ unpack: bool = ...,
+ ndmin: L[0, 1, 2] = ...,
+ encoding: None | str = ...,
+ max_rows: None | int = ...,
+ *,
+ like: None | ArrayLike = ...
+) -> NDArray[_SCT]: ...
+@overload
def loadtxt(
- fname,
- dtype=...,
- comments=...,
- delimiter=...,
- converters=...,
- skiprows=...,
- usecols=...,
- unpack=...,
- ndmin=...,
- encoding=...,
- max_rows=...,
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: DTypeLike,
+ comments: str | Sequence[str] = ...,
+ delimiter: None | str = ...,
+ converters: None | Mapping[int | str, Callable[[str], Any]] = ...,
+ skiprows: int = ...,
+ usecols: int | Sequence[int] = ...,
+ unpack: bool = ...,
+ ndmin: L[0, 1, 2] = ...,
+ encoding: None | str = ...,
+ max_rows: None | int = ...,
*,
- like=...,
-): ...
+ like: None | ArrayLike = ...
+) -> NDArray[Any]: ...
+
def savetxt(
- fname,
- X,
- fmt=...,
- delimiter=...,
- newline=...,
- header=...,
- footer=...,
- comments=...,
- encoding=...,
-): ...
-def fromregex(file, regexp, dtype, encoding=...): ...
+ fname: str | os.PathLike[str] | IO[Any],
+ X: ArrayLike,
+ fmt: str | Sequence[str] = ...,
+ delimiter: str = ...,
+ newline: str = ...,
+ header: str = ...,
+ footer: str = ...,
+ comments: str = ...,
+ encoding: None | str = ...,
+) -> None: ...
+
+@overload
+def fromregex(
+ file: str | IO[Any],
+ regexp: str | bytes | Pattern[Any],
+ dtype: _DTypeLike[_SCT],
+ encoding: None | str = ...
+) -> NDArray[_SCT]: ...
+@overload
+def fromregex(
+ file: str | IO[Any],
+ regexp: str | bytes | Pattern[Any],
+ dtype: DTypeLike,
+ encoding: None | str = ...
+) -> NDArray[Any]: ...
+
+# TODO: Sort out arguments
+@overload
+def genfromtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: None = ...,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[float64]: ...
+@overload
def genfromtxt(
- fname,
- dtype=...,
- comments=...,
- delimiter=...,
- skip_header=...,
- skip_footer=...,
- converters=...,
- missing_values=...,
- filling_values=...,
- usecols=...,
- names=...,
- excludelist=...,
- deletechars=...,
- replace_space=...,
- autostrip=...,
- case_sensitive=...,
- defaultfmt=...,
- unpack=...,
- usemask=...,
- loose=...,
- invalid_raise=...,
- max_rows=...,
- encoding=...,
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: _DTypeLike[_SCT],
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[_SCT]: ...
+@overload
+def genfromtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ dtype: DTypeLike,
+ *args: Any,
+ **kwargs: Any,
+) -> NDArray[Any]: ...
+
+@overload
+def recfromtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ *,
+ usemask: L[False] = ...,
+ **kwargs: Any,
+) -> recarray[Any, dtype[void]]: ...
+@overload
+def recfromtxt(
+ fname: str | os.PathLike[str] | IO[Any],
+ *,
+ usemask: L[True],
+ **kwargs: Any,
+) -> MaskedRecords[Any, dtype[void]]: ...
+
+@overload
+def recfromcsv(
+ fname: str | os.PathLike[str] | IO[Any],
+ *,
+ usemask: L[False] = ...,
+ **kwargs: Any,
+) -> recarray[Any, dtype[void]]: ...
+@overload
+def recfromcsv(
+ fname: str | os.PathLike[str] | IO[Any],
*,
- like=...,
-): ...
-def recfromtxt(fname, **kwargs): ...
-def recfromcsv(fname, **kwargs): ...
+ usemask: L[True],
+ **kwargs: Any,
+) -> MaskedRecords[Any, dtype[void]]: ...
diff --git a/numpy/typing/tests/data/fail/modules.py b/numpy/typing/tests/data/fail/modules.py
index 7b9309329..59e724f22 100644
--- a/numpy/typing/tests/data/fail/modules.py
+++ b/numpy/typing/tests/data/fail/modules.py
@@ -12,7 +12,6 @@ np.math # E: Module has no attribute
# Public sub-modules that are not imported to their parent module by default;
# e.g. one must first execute `import numpy.lib.recfunctions`
np.lib.recfunctions # E: Module has no attribute
-np.ma.mrecords # E: Module has no attribute
np.__NUMPY_SETUP__ # E: Module has no attribute
np.__deprecated_attrs__ # E: Module has no attribute
diff --git a/numpy/typing/tests/data/fail/npyio.py b/numpy/typing/tests/data/fail/npyio.py
new file mode 100644
index 000000000..89c511c1c
--- /dev/null
+++ b/numpy/typing/tests/data/fail/npyio.py
@@ -0,0 +1,31 @@
+import pathlib
+from typing import IO
+
+import numpy.typing as npt
+import numpy as np
+
+str_path: str
+bytes_path: bytes
+pathlib_path: pathlib.Path
+str_file: IO[str]
+AR_i8: npt.NDArray[np.int64]
+
+np.load(str_file) # E: incompatible type
+
+np.save(bytes_path, AR_i8) # E: incompatible type
+np.save(str_file, AR_i8) # E: incompatible type
+
+np.savez(bytes_path, AR_i8) # E: incompatible type
+np.savez(str_file, AR_i8) # E: incompatible type
+
+np.savez_compressed(bytes_path, AR_i8) # E: incompatible type
+np.savez_compressed(str_file, AR_i8) # E: incompatible type
+
+np.loadtxt(bytes_path) # E: No overload variant
+
+np.fromregex(bytes_path, ".", np.int64) # E: No overload variant
+np.fromregex(pathlib_path, ".", np.int64) # E: No overload variant
+
+np.recfromtxt(bytes_path) # E: No overload variant
+
+np.recfromcsv(bytes_path) # E: No overload variant
diff --git a/numpy/typing/tests/data/reveal/npyio.py b/numpy/typing/tests/data/reveal/npyio.py
new file mode 100644
index 000000000..36c0c540b
--- /dev/null
+++ b/numpy/typing/tests/data/reveal/npyio.py
@@ -0,0 +1,71 @@
+import re
+import pathlib
+from typing import IO, List
+
+import numpy.typing as npt
+import numpy as np
+
+str_path: str
+pathlib_path: pathlib.Path
+str_file: IO[str]
+bytes_file: IO[bytes]
+
+bag_obj: np.lib.npyio.BagObj[int]
+npz_file: np.lib.npyio.NpzFile
+
+AR_i8: npt.NDArray[np.int64]
+AR_LIKE_f8: List[float]
+
+reveal_type(bag_obj.a) # E: int
+reveal_type(bag_obj.b) # E: int
+
+reveal_type(npz_file.zip) # E: zipfile.ZipFile
+reveal_type(npz_file.fid) # E: Union[None, typing.IO[builtins.str]]
+reveal_type(npz_file.files) # E: list[builtins.str]
+reveal_type(npz_file.allow_pickle) # E: bool
+reveal_type(npz_file.pickle_kwargs) # E: Union[None, typing.Mapping[builtins.str, Any]]
+reveal_type(npz_file.f) # E: numpy.lib.npyio.BagObj[numpy.lib.npyio.NpzFile]
+reveal_type(npz_file["test"]) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+reveal_type(len(npz_file)) # E: int
+with npz_file as f:
+ reveal_type(f) # E: numpy.lib.npyio.NpzFile
+
+reveal_type(np.load(bytes_file)) # E: Any
+reveal_type(np.load(pathlib_path, allow_pickle=True)) # E: Any
+reveal_type(np.load(str_path, encoding="bytes")) # E: Any
+
+reveal_type(np.save(bytes_file, AR_LIKE_f8)) # E: None
+reveal_type(np.save(pathlib_path, AR_i8, allow_pickle=True)) # E: None
+reveal_type(np.save(str_path, AR_LIKE_f8)) # E: None
+
+reveal_type(np.savez(bytes_file, AR_LIKE_f8)) # E: None
+reveal_type(np.savez(pathlib_path, ar1=AR_i8, ar2=AR_i8)) # E: None
+reveal_type(np.savez(str_path, AR_LIKE_f8, ar1=AR_i8)) # E: None
+
+reveal_type(np.savez_compressed(bytes_file, AR_LIKE_f8)) # E: None
+reveal_type(np.savez_compressed(pathlib_path, ar1=AR_i8, ar2=AR_i8)) # E: None
+reveal_type(np.savez_compressed(str_path, AR_LIKE_f8, ar1=AR_i8)) # E: None
+
+reveal_type(np.loadtxt(bytes_file)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.loadtxt(pathlib_path, dtype=np.str_)) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(np.loadtxt(str_path, dtype=str, skiprows=2)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+reveal_type(np.loadtxt(str_file, comments="test")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.loadtxt(str_path, delimiter="\n")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.loadtxt(str_path, ndmin=2)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+
+reveal_type(np.fromregex(bytes_file, "test", np.float64)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.fromregex(str_file, b"test", dtype=float)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+reveal_type(np.fromregex(str_path, re.compile("test"), dtype=np.str_, encoding="utf8")) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+
+reveal_type(np.genfromtxt(bytes_file)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.genfromtxt(pathlib_path, dtype=np.str_)) # E: numpy.ndarray[Any, numpy.dtype[numpy.str_]]
+reveal_type(np.genfromtxt(str_path, dtype=str, skiprows=2)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
+reveal_type(np.genfromtxt(str_file, comments="test")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.genfromtxt(str_path, delimiter="\n")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+reveal_type(np.genfromtxt(str_path, ndmin=2)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
+
+reveal_type(np.recfromtxt(bytes_file)) # E: numpy.recarray[Any, numpy.dtype[numpy.void]]
+reveal_type(np.recfromtxt(pathlib_path, usemask=True)) # E: numpy.ma.mrecords.MaskedRecords[Any, numpy.dtype[numpy.void]]
+
+reveal_type(np.recfromcsv(bytes_file)) # E: numpy.recarray[Any, numpy.dtype[numpy.void]]
+reveal_type(np.recfromcsv(pathlib_path, usemask=True)) # E: numpy.ma.mrecords.MaskedRecords[Any, numpy.dtype[numpy.void]]
diff --git a/runtests.py b/runtests.py
index fcfa4c567..855fd7157 100755
--- a/runtests.py
+++ b/runtests.py
@@ -476,13 +476,17 @@ def build_project(args):
py_v_s = sysconfig.get_config_var('py_version_short')
platlibdir = getattr(sys, 'platlibdir', '') # Python3.9+
- site_dir_template = sysconfig.get_path('platlib', expand=False)
+ site_dir_template = os.path.normpath(sysconfig.get_path(
+ 'platlib', expand=False
+ ))
site_dir = site_dir_template.format(platbase=dst_dir,
py_version_short=py_v_s,
platlibdir=platlibdir,
base=dst_dir,
)
- noarch_template = sysconfig.get_path('purelib', expand=False)
+ noarch_template = os.path.normpath(sysconfig.get_path(
+ 'purelib', expand=False
+ ))
site_dir_noarch = noarch_template.format(base=dst_dir,
py_version_short=py_v_s,
platlibdir=platlibdir,