summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2020-09-17 01:43:49 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2020-10-05 16:08:53 +0200
commit5dbcbb79478a8eddf05a72f7030bdf29a93ff46c (patch)
tree7bf95281bda48c6448e645f471de92ced5a29f2b
parent9ebd2bfa64eab788d529af7404f85dc3a58ff411 (diff)
downloadnumpy-5dbcbb79478a8eddf05a72f7030bdf29a93ff46c.tar.gz
ENH: Add annotations for `generic` and `ndarray` bitwise operations
-rw-r--r--numpy/__init__.pyi79
-rw-r--r--numpy/typing/_callable.py29
2 files changed, 91 insertions, 17 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index 9966ef199..1dd5d5552 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -7,12 +7,15 @@ from numpy.core._internal import _ctypes
from numpy.typing import ArrayLike, DtypeLike, _Shape, _ShapeLike
from numpy.typing._callable import (
_BoolOp,
+ _BoolBitOp,
_BoolSub,
_BoolTrueDiv,
_TD64Div,
_IntTrueDiv,
_UnsignedIntOp,
+ _UnsignedIntBitOp,
_SignedIntOp,
+ _SignedIntBitOp,
_FloatOp,
_ComplexOp,
_NumberOp,
@@ -677,20 +680,9 @@ class _ArrayOrScalarCommon(
def __rmod__(self, other): ...
def __divmod__(self, other): ...
def __rdivmod__(self, other): ...
- def __lshift__(self, other): ...
- def __rlshift__(self, other): ...
- def __rshift__(self, other): ...
- def __rrshift__(self, other): ...
- def __and__(self, other): ...
- def __rand__(self, other): ...
- def __xor__(self, other): ...
- def __rxor__(self, other): ...
- def __or__(self, other): ...
- def __ror__(self, other): ...
def __neg__(self: _ArraySelf) -> _ArraySelf: ...
def __pos__(self: _ArraySelf) -> _ArraySelf: ...
def __abs__(self: _ArraySelf) -> _ArraySelf: ...
- def __invert__(self: _ArraySelf) -> _ArraySelf: ...
def astype(
self: _ArraySelf,
dtype: DtypeLike,
@@ -1257,6 +1249,17 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
def __rpow__(self, other: ArrayLike) -> Union[ndarray, generic]: ...
def __truediv__(self, other: ArrayLike) -> Union[ndarray, generic]: ...
def __rtruediv__(self, other: ArrayLike) -> Union[ndarray, generic]: ...
+ def __invert__(self: _ArraySelf) -> Union[_ArraySelf, integer, bool_]: ...
+ def __lshift__(self, other: ArrayLike) -> Union[ndarray, integer]: ...
+ def __rlshift__(self, other: ArrayLike) -> Union[ndarray, integer]: ...
+ def __rshift__(self, other: ArrayLike) -> Union[ndarray, integer]: ...
+ def __rrshift__(self, other: ArrayLike) -> Union[ndarray, integer]: ...
+ def __and__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
+ def __rand__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
+ def __xor__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
+ def __rxor__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
+ def __or__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
+ def __ror__(self, other: ArrayLike) -> Union[ndarray, integer, bool_]: ...
# `np.generic` does not support inplace operations
def __iadd__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
def __isub__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
@@ -1265,11 +1268,11 @@ class ndarray(_ArrayOrScalarCommon, Iterable, Sized, Container):
def __ifloordiv__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
def __ipow__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
def __imod__(self, other): ...
- def __ilshift__(self, other): ...
- def __irshift__(self, other): ...
- def __iand__(self, other): ...
- def __ixor__(self, other): ...
- def __ior__(self, other): ...
+ def __ilshift__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
+ def __irshift__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
+ def __iand__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
+ def __ixor__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
+ def __ior__(self: _ArraySelf, other: ArrayLike) -> _ArraySelf: ...
# NOTE: while `np.generic` is not technically an instance of `ABCMeta`,
# the `@abstractmethod` decorator is herein used to (forcefully) deny
@@ -1329,6 +1332,17 @@ class bool_(generic):
__rpow__: _BoolOp[int8]
__truediv__: _BoolTrueDiv
__rtruediv__: _BoolTrueDiv
+ def __invert__(self) -> bool_: ...
+ __lshift__: _BoolBitOp[int8]
+ __rlshift__: _BoolBitOp[int8]
+ __rshift__: _BoolBitOp[int8]
+ __rrshift__: _BoolBitOp[int8]
+ __and__: _BoolBitOp[bool_]
+ __rand__: _BoolBitOp[bool_]
+ __xor__: _BoolBitOp[bool_]
+ __rxor__: _BoolBitOp[bool_]
+ __or__: _BoolBitOp[bool_]
+ __ror__: _BoolBitOp[bool_]
class object_(generic):
def __init__(self, __value: object = ...) -> None: ...
@@ -1374,6 +1388,18 @@ class integer(number): # type: ignore
def __index__(self) -> int: ...
__truediv__: _IntTrueDiv
__rtruediv__: _IntTrueDiv
+ def __invert__(self: _IntType) -> _IntType: ...
+ # Ensure that objects annotated as `integer` support bit-wise operations
+ def __lshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __rlshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __rshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __rrshift__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __and__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __rand__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __or__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __ror__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __xor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
+ def __rxor__(self, other: Union[_IntLike, _BoolLike]) -> integer: ...
class signedinteger(integer): # type: ignore
__add__: _SignedIntOp
@@ -1386,6 +1412,16 @@ class signedinteger(integer): # type: ignore
__rfloordiv__: _SignedIntOp
__pow__: _SignedIntOp
__rpow__: _SignedIntOp
+ __lshift__: _SignedIntBitOp
+ __rlshift__: _SignedIntBitOp
+ __rshift__: _SignedIntBitOp
+ __rrshift__: _SignedIntBitOp
+ __and__: _SignedIntBitOp
+ __rand__: _SignedIntBitOp
+ __xor__: _SignedIntBitOp
+ __rxor__: _SignedIntBitOp
+ __or__: _SignedIntBitOp
+ __ror__: _SignedIntBitOp
class int8(signedinteger):
def __init__(self, __value: _IntValue = ...) -> None: ...
@@ -1429,6 +1465,16 @@ class unsignedinteger(integer): # type: ignore
__rfloordiv__: _UnsignedIntOp
__pow__: _UnsignedIntOp
__rpow__: _UnsignedIntOp
+ __lshift__: _UnsignedIntBitOp
+ __rlshift__: _UnsignedIntBitOp
+ __rshift__: _UnsignedIntBitOp
+ __rrshift__: _UnsignedIntBitOp
+ __and__: _UnsignedIntBitOp
+ __rand__: _UnsignedIntBitOp
+ __xor__: _UnsignedIntBitOp
+ __rxor__: _UnsignedIntBitOp
+ __or__: _UnsignedIntBitOp
+ __ror__: _UnsignedIntBitOp
class uint8(unsignedinteger):
def __init__(self, __value: _IntValue = ...) -> None: ...
@@ -1458,6 +1504,7 @@ class floating(inexact): # type: ignore
__pow__: _FloatOp
__rpow__: _FloatOp
+_IntType = TypeVar("_IntType", bound=integer)
_FloatType = TypeVar('_FloatType', bound=floating)
class float16(floating):
diff --git a/numpy/typing/_callable.py b/numpy/typing/_callable.py
index 5e14b708f..943441cf4 100644
--- a/numpy/typing/_callable.py
+++ b/numpy/typing/_callable.py
@@ -9,7 +9,7 @@ See the `Mypy documentation`_ on protocols for more details.
"""
import sys
-from typing import Union, TypeVar, overload, Any
+from typing import Union, TypeVar, overload, Any, NoReturn
from numpy import (
_BoolLike,
@@ -26,6 +26,7 @@ from numpy import (
signedinteger,
int32,
int64,
+ uint64,
floating,
float32,
float64,
@@ -45,6 +46,7 @@ else:
HAVE_PROTOCOL = True
if HAVE_PROTOCOL:
+ _IntType = TypeVar("_IntType", bound=integer)
_NumberType = TypeVar("_NumberType", bound=number)
_NumberType_co = TypeVar("_NumberType_co", covariant=True, bound=number)
_GenericType_co = TypeVar("_GenericType_co", covariant=True, bound=generic)
@@ -61,6 +63,14 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: _NumberType) -> _NumberType: ...
+ class _BoolBitOp(Protocol[_GenericType_co]):
+ @overload
+ def __call__(self, __other: _BoolLike) -> _GenericType_co: ...
+ @overload # platform dependent
+ def __call__(self, __other: int) -> Union[int32, int64]: ...
+ @overload
+ def __call__(self, __other: _IntType) -> _IntType: ...
+
class _BoolSub(Protocol):
# Note that `__other: bool_` is absent here
@overload # platform dependent
@@ -103,6 +113,17 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ class _UnsignedIntBitOp(Protocol):
+ # The likes of `uint64 | np.signedinteger` will fail as there
+ # is no signed integer type large enough to hold a `uint64`
+ # See https://github.com/numpy/numpy/issues/2524
+ @overload
+ def __call__(self, __other: Union[bool, unsignedinteger]) -> unsignedinteger: ...
+ @overload
+ def __call__(self: uint64, __other: Union[int, signedinteger]) -> NoReturn: ...
+ @overload
+ def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+
class _SignedIntOp(Protocol):
@overload
def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
@@ -111,6 +132,9 @@ if HAVE_PROTOCOL:
@overload
def __call__(self, __other: complex) -> complexfloating[floating]: ...
+ class _SignedIntBitOp(Protocol):
+ def __call__(self, __other: Union[int, signedinteger]) -> signedinteger: ...
+
class _FloatOp(Protocol):
@overload
def __call__(self, __other: _FloatLike) -> floating: ...
@@ -125,12 +149,15 @@ if HAVE_PROTOCOL:
else:
_BoolOp = Any
+ _BoolBitOp = Any
_BoolSub = Any
_BoolTrueDiv = Any
_TD64Div = Any
_IntTrueDiv = Any
_UnsignedIntOp = Any
+ _UnsignedIntBitOp = Any
_SignedIntOp = Any
+ _SignedIntBitOp = Any
_FloatOp = Any
_ComplexOp = Any
_NumberOp = Any