summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSybren A. Stüvel <sybren@stuvel.eu>2019-08-04 16:41:01 +0200
committerSybren A. Stüvel <sybren@stuvel.eu>2019-08-04 17:05:58 +0200
commitb6cebd53fcafd3088fc8361f6d3466166f75410b (patch)
treea1a3912fb9e91e249e433df0a9b79572f46340f3
parent6760eb76e665dc81863a82110164c4b3b38e7ee9 (diff)
downloadrsa-git-b6cebd53fcafd3088fc8361f6d3466166f75410b.tar.gz
Added type annotations + some fixes to get them correct
One functional change: `CryptoOperation.read_infile()` now reads bytes from `sys.stdin` instead of text. This is necessary to be consistent with the rest of the code, which all deals with bytes.
-rw-r--r--rsa/_compat.py7
-rw-r--r--rsa/cli.py44
-rw-r--r--rsa/common.py19
-rw-r--r--rsa/core.py6
-rw-r--r--rsa/key.py95
-rw-r--r--rsa/parallel.py4
-rw-r--r--rsa/pem.py10
-rw-r--r--rsa/pkcs1.py36
-rw-r--r--rsa/pkcs1_v2.py2
-rw-r--r--rsa/prime.py12
-rw-r--r--rsa/randnum.py8
-rw-r--r--rsa/util.py2
12 files changed, 129 insertions, 116 deletions
diff --git a/rsa/_compat.py b/rsa/_compat.py
index 843583c..b31331e 100644
--- a/rsa/_compat.py
+++ b/rsa/_compat.py
@@ -21,14 +21,11 @@ import sys
from struct import pack
-def byte(num):## XXX
+def byte(num: int):
"""
Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
representation.
- Use it as a replacement for ``chr`` where you are expecting a byte
- because this will work on all current versions of Python::
-
:param num:
An unsigned integer between 0 and 255 (both inclusive).
:returns:
@@ -37,7 +34,7 @@ def byte(num):## XXX
return pack("B", num)
-def xor_bytes(b1, b2):
+def xor_bytes(b1: bytes, b2: bytes) -> bytes:
"""
Returns the bitwise XOR result between two bytes objects, b1 ^ b2.
diff --git a/rsa/cli.py b/rsa/cli.py
index cbf3f97..60bc07c 100644
--- a/rsa/cli.py
+++ b/rsa/cli.py
@@ -22,20 +22,21 @@ These scripts are called by the executables defined in setup.py.
import abc
import sys
import typing
-from optparse import OptionParser
+import optparse
import rsa
import rsa.key
import rsa.pkcs1
HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())
+Indexable = typing.Union[typing.Tuple, typing.List[str]]
-def keygen():
+def keygen() -> None:
"""Key generator."""
# Parse the CLI options
- parser = OptionParser(usage='usage: %prog [options] keysize',
+ parser = optparse.OptionParser(usage='usage: %prog [options] keysize',
description='Generates a new RSA keypair of "keysize" bits.')
parser.add_option('--pubout', type='string',
@@ -104,13 +105,14 @@ class CryptoOperation(metaclass=abc.ABCMeta):
key_class = rsa.PublicKey # type: typing.Type[rsa.key.AbstractKey]
- def __init__(self):
+ def __init__(self) -> None:
self.usage = self.usage % self.__class__.__dict__
self.input_help = self.input_help % self.__class__.__dict__
self.output_help = self.output_help % self.__class__.__dict__
@abc.abstractmethod
- def perform_operation(self, indata, key, cli_args):
+ def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey,
+ cli_args: Indexable):
"""Performs the program's operation.
Implement in a subclass.
@@ -118,7 +120,7 @@ class CryptoOperation(metaclass=abc.ABCMeta):
:returns: the data to write to the output.
"""
- def __call__(self):
+ def __call__(self) -> None:
"""Runs the program."""
(cli, cli_args) = self.parse_cli()
@@ -133,13 +135,13 @@ class CryptoOperation(metaclass=abc.ABCMeta):
if self.has_output:
self.write_outfile(outdata, cli.output)
- def parse_cli(self):
+ def parse_cli(self) -> typing.Tuple[optparse.Values, typing.List[str]]:
"""Parse the CLI options
:returns: (cli_opts, cli_args)
"""
- parser = OptionParser(usage=self.usage, description=self.description)
+ parser = optparse.OptionParser(usage=self.usage, description=self.description)
parser.add_option('-i', '--input', type='string', help=self.input_help)
@@ -158,7 +160,7 @@ class CryptoOperation(metaclass=abc.ABCMeta):
return cli, cli_args
- def read_key(self, filename, keyform):
+ def read_key(self, filename: str, keyform: str) -> rsa.key.AbstractKey:
"""Reads a public or private key."""
print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr)
@@ -167,7 +169,7 @@ class CryptoOperation(metaclass=abc.ABCMeta):
return self.key_class.load_pkcs1(keydata, keyform)
- def read_infile(self, inname):
+ def read_infile(self, inname: str) -> bytes:
"""Read the input file"""
if inname:
@@ -176,9 +178,9 @@ class CryptoOperation(metaclass=abc.ABCMeta):
return infile.read()
print('Reading input from stdin', file=sys.stderr)
- return sys.stdin.read()
+ return sys.stdin.buffer.read()
- def write_outfile(self, outdata, outname):
+ def write_outfile(self, outdata: bytes, outname: str) -> None:
"""Write the output file"""
if outname:
@@ -200,9 +202,10 @@ class EncryptOperation(CryptoOperation):
operation_past = 'encrypted'
operation_progressive = 'encrypting'
- def perform_operation(self, indata, pub_key, cli_args=None):
+ def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
+ cli_args: Indexable=()):
"""Encrypts files."""
-
+ assert isinstance(pub_key, rsa.key.PublicKey)
return rsa.encrypt(indata, pub_key)
@@ -217,9 +220,10 @@ class DecryptOperation(CryptoOperation):
operation_progressive = 'decrypting'
key_class = rsa.PrivateKey
- def perform_operation(self, indata, priv_key, cli_args=None):
+ def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
+ cli_args: Indexable=()):
"""Decrypts files."""
-
+ assert isinstance(priv_key, rsa.key.PrivateKey)
return rsa.decrypt(indata, priv_key)
@@ -239,8 +243,10 @@ class SignOperation(CryptoOperation):
output_help = ('Name of the file to write the signature to. Written '
'to stdout if this option is not present.')
- def perform_operation(self, indata, priv_key, cli_args):
+ def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
+ cli_args: Indexable):
"""Signs files."""
+ assert isinstance(priv_key, rsa.key.PrivateKey)
hash_method = cli_args[1]
if hash_method not in HASH_METHODS:
@@ -264,8 +270,10 @@ class VerifyOperation(CryptoOperation):
expected_cli_args = 2
has_output = False
- def perform_operation(self, indata, pub_key, cli_args):
+ def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
+ cli_args: Indexable):
"""Verifies files."""
+ assert isinstance(pub_key, rsa.key.PublicKey)
signature_file = cli_args[1]
diff --git a/rsa/common.py b/rsa/common.py
index a4337f6..b983b98 100644
--- a/rsa/common.py
+++ b/rsa/common.py
@@ -16,17 +16,18 @@
"""Common functionality shared by several modules."""
+import typing
+
class NotRelativePrimeError(ValueError):
- def __init__(self, a, b, d, msg=None):
- super(NotRelativePrimeError, self).__init__(
- msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
+ def __init__(self, a, b, d, msg=''):
+ super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
self.a = a
self.b = b
self.d = d
-def bit_size(num):
+def bit_size(num: int) -> int:
"""
Number of bits needed to represent a integer excluding any prefix
0 bits.
@@ -54,7 +55,7 @@ def bit_size(num):
raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
-def byte_size(number):
+def byte_size(number: int) -> int:
"""
Returns the number of bytes required to hold a specific long number.
@@ -79,7 +80,7 @@ def byte_size(number):
return ceil_div(bit_size(number), 8)
-def ceil_div(num, div):
+def ceil_div(num: int, div: int) -> int:
"""
Returns the ceiling function of a division between `num` and `div`.
@@ -103,7 +104,7 @@ def ceil_div(num, div):
return quanta
-def extended_gcd(a, b):
+def extended_gcd(a: int, b: int) -> typing.Tuple[int, int, int]:
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
"""
# r = gcd(a,b) i = multiplicitive inverse of a mod b
@@ -128,7 +129,7 @@ def extended_gcd(a, b):
return a, lx, ly # Return only positive values
-def inverse(x, n):
+def inverse(x: int, n: int) -> int:
"""Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)
>>> inverse(7, 4)
@@ -145,7 +146,7 @@ def inverse(x, n):
return inv
-def crt(a_values, modulo_values):
+def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]) -> int:
"""Chinese Remainder Theorem.
Calculates x such that x = a[i] (mod m[i]) for each i.
diff --git a/rsa/core.py b/rsa/core.py
index 0660881..42f7bac 100644
--- a/rsa/core.py
+++ b/rsa/core.py
@@ -21,14 +21,14 @@ mathematically on integers.
"""
-def assert_int(var, name):
+def assert_int(var: int, name: str):
if isinstance(var, int):
return
raise TypeError('%s should be an integer, not %s' % (name, var.__class__))
-def encrypt_int(message, ekey, n):
+def encrypt_int(message: int, ekey: int, n: int) -> int:
"""Encrypts a message using encryption key 'ekey', working modulo n"""
assert_int(message, 'message')
@@ -44,7 +44,7 @@ def encrypt_int(message, ekey, n):
return pow(message, ekey, n)
-def decrypt_int(cyphertext, dkey, n):
+def decrypt_int(cyphertext: int, dkey: int, n: int) -> int:
"""Decrypts a cypher text using the decryption key 'dkey', working modulo n"""
assert_int(cyphertext, 'cyphertext')
diff --git a/rsa/key.py b/rsa/key.py
index 1565967..05c77ef 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -34,6 +34,7 @@ of pyasn1.
"""
import logging
+import typing
import warnings
import rsa.prime
@@ -47,17 +48,17 @@ log = logging.getLogger(__name__)
DEFAULT_EXPONENT = 65537
-class AbstractKey(object):
+class AbstractKey:
"""Abstract superclass for private and public keys."""
__slots__ = ('n', 'e')
- def __init__(self, n, e):
+ def __init__(self, n: int, e: int) -> None:
self.n = n
self.e = e
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a PEM-encoded file that contains
@@ -69,7 +70,7 @@ class AbstractKey(object):
"""
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
:param keyfile: contents of a DER-encoded file that contains
@@ -80,14 +81,14 @@ class AbstractKey(object):
:rtype: AbstractKey
"""
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves the key in PKCS#1 PEM format, implement in a subclass.
:returns: the PEM-encoded key.
:rtype: bytes
"""
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the key in PKCS#1 DER format, implement in a subclass.
:returns: the DER-encoded key.
@@ -95,7 +96,7 @@ class AbstractKey(object):
"""
@classmethod
- def load_pkcs1(cls, keyfile, format='PEM'):
+ def load_pkcs1(cls, keyfile: bytes, format='PEM') -> 'AbstractKey':
"""Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
@@ -117,7 +118,7 @@ class AbstractKey(object):
return method(keyfile)
@staticmethod
- def _assert_format_exists(file_format, methods):
+ def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) -> typing.Callable:
"""Checks whether the given file format exists in 'methods'.
"""
@@ -128,7 +129,7 @@ class AbstractKey(object):
raise ValueError('Unsupported format: %r, try one of %s' % (file_format,
formats))
- def save_pkcs1(self, format='PEM'):
+ def save_pkcs1(self, format='PEM') -> bytes:
"""Saves the key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
@@ -145,7 +146,7 @@ class AbstractKey(object):
method = self._assert_format_exists(format, methods)
return method()
- def blind(self, message, r):
+ def blind(self, message: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param message: the message, as integer, to blind.
@@ -162,7 +163,7 @@ class AbstractKey(object):
return (message * pow(r, self.e, self.n)) % self.n
- def unblind(self, blinded, r):
+ def unblind(self, blinded: int, r: int) -> int:
"""Performs blinding on the message using random number 'r'.
:param blinded: the blinded message, as integer, to unblind.
@@ -206,18 +207,18 @@ class PublicKey(AbstractKey):
def __getitem__(self, key):
return getattr(self, key)
- def __repr__(self):
+ def __repr__(self) -> str:
return 'PublicKey(%i, %i)' % (self.n, self.e)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int]) -> None:
"""Sets the key from tuple."""
self.n, self.e = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -226,14 +227,14 @@ class PublicKey(AbstractKey):
return self.n == other.n and self.e == other.e
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e))
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the public
@@ -259,7 +260,7 @@ class PublicKey(AbstractKey):
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the public key in PKCS#1 DER format.
:returns: the DER-encoded public key.
@@ -277,7 +278,7 @@ class PublicKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 PEM-encoded public key file.
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
@@ -291,7 +292,7 @@ class PublicKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded public key file.
:return: contents of a PEM-encoded file that contains the public key.
@@ -302,7 +303,7 @@ class PublicKey(AbstractKey):
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
@classmethod
- def load_pkcs1_openssl_pem(cls, keyfile):
+ def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
These files can be recognised in that they start with BEGIN PUBLIC KEY
@@ -321,14 +322,12 @@ class PublicKey(AbstractKey):
return cls.load_pkcs1_openssl_der(der)
@classmethod
- def load_pkcs1_openssl_der(cls, keyfile):
+ def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey':
"""Loads a PKCS#1 DER-encoded public key file from OpenSSL.
:param keyfile: contents of a DER-encoded file that contains the public
key, from OpenSSL.
:return: a PublicKey object
- :rtype: bytes
-
"""
from rsa.asn1 import OpenSSLPubKey
@@ -369,7 +368,7 @@ class PrivateKey(AbstractKey):
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
- def __init__(self, n, e, d, p, q):
+ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None:
AbstractKey.__init__(self, n, e)
self.d = d
self.p = p
@@ -383,18 +382,18 @@ class PrivateKey(AbstractKey):
def __getitem__(self, key):
return getattr(self, key)
- def __repr__(self):
- return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
+ def __repr__(self) -> str:
+ return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q)
- def __getstate__(self):
+ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]:
"""Returns the key as tuple for pickling."""
return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef
- def __setstate__(self, state):
+ def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]):
"""Sets the key from tuple."""
self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state
- def __eq__(self, other):
+ def __eq__(self, other: typing.Any) -> bool:
if other is None:
return False
@@ -410,13 +409,13 @@ class PrivateKey(AbstractKey):
self.exp2 == other.exp2 and
self.coef == other.coef)
- def __ne__(self, other):
+ def __ne__(self, other: typing.Any) -> bool:
return not (self == other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef))
- def blinded_decrypt(self, encrypted):
+ def blinded_decrypt(self, encrypted: int) -> int:
"""Decrypts the message using blinding to prevent side-channel attacks.
:param encrypted: the encrypted message
@@ -432,7 +431,7 @@ class PrivateKey(AbstractKey):
return self.unblind(decrypted, blind_r)
- def blinded_encrypt(self, message):
+ def blinded_encrypt(self, message: int) -> int:
"""Encrypts the message using blinding to prevent side-channel attacks.
:param message: the message to encrypt
@@ -448,7 +447,7 @@ class PrivateKey(AbstractKey):
return self.unblind(encrypted, blind_r)
@classmethod
- def _load_pkcs1_der(cls, keyfile):
+ def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a key in PKCS#1 DER format.
:param keyfile: contents of a DER-encoded file that contains the private
@@ -505,7 +504,7 @@ class PrivateKey(AbstractKey):
return key
- def _save_pkcs1_der(self):
+ def _save_pkcs1_der(self) -> bytes:
"""Saves the private key in PKCS#1 DER format.
:returns: the DER-encoded private key.
@@ -543,7 +542,7 @@ class PrivateKey(AbstractKey):
return encoder.encode(asn_key)
@classmethod
- def _load_pkcs1_pem(cls, keyfile):
+ def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey':
"""Loads a PKCS#1 PEM-encoded private key file.
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
@@ -558,7 +557,7 @@ class PrivateKey(AbstractKey):
der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY')
return cls._load_pkcs1_der(der)
- def _save_pkcs1_pem(self):
+ def _save_pkcs1_pem(self) -> bytes:
"""Saves a PKCS#1 PEM-encoded private key file.
:return: contents of a PEM-encoded file that contains the private key.
@@ -569,7 +568,7 @@ class PrivateKey(AbstractKey):
return rsa.pem.save_pem(der, b'RSA PRIVATE KEY')
-def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
+def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
@@ -647,7 +646,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
return max(p, q), min(p, q)
-def calculate_keys_custom_exponent(p, q, exponent):
+def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p, q and an exponent,
and returns them as a tuple (e, d)
@@ -677,7 +676,7 @@ def calculate_keys_custom_exponent(p, q, exponent):
return exponent, d
-def calculate_keys(p, q):
+def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]:
"""Calculates an encryption and a decryption key given p and q, and
returns them as a tuple (e, d)
@@ -690,7 +689,10 @@ def calculate_keys(p, q):
return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT)
-def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
+def gen_keys(nbits: int,
+ getprime_func: typing.Callable[[int], int],
+ accurate=True,
+ exponent=DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]:
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
@@ -718,7 +720,8 @@ def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT):
return p, q, e, d
-def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
+def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \
+ -> typing.Tuple[PublicKey, PrivateKey]:
"""Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
@@ -753,9 +756,9 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
# Determine which getprime function to use
if poolsize > 1:
from rsa import parallel
- import functools
- getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
+ def getprime_func(nbits):
+ return parallel.getprime(nbits, poolsize=poolsize)
else:
getprime_func = rsa.prime.getprime
diff --git a/rsa/parallel.py b/rsa/parallel.py
index ef9f07f..e81bcad 100644
--- a/rsa/parallel.py
+++ b/rsa/parallel.py
@@ -30,7 +30,7 @@ import rsa.prime
import rsa.randnum
-def _find_prime(nbits, pipe):
+def _find_prime(nbits: int, pipe) -> None:
while True:
integer = rsa.randnum.read_random_odd_int(nbits)
@@ -40,7 +40,7 @@ def _find_prime(nbits, pipe):
return
-def getprime(nbits, poolsize):
+def getprime(nbits: int, poolsize: int) -> int:
"""Returns a prime number that can be stored in 'nbits' bits.
Works in multiple threads at the same time.
diff --git a/rsa/pem.py b/rsa/pem.py
index 0650e64..02c7691 100644
--- a/rsa/pem.py
+++ b/rsa/pem.py
@@ -17,9 +17,13 @@
"""Functions that load and write PEM-encoded files."""
import base64
+import typing
+# Should either be ASCII strings or bytes.
+FlexiText = typing.Union[str, bytes]
-def _markers(pem_marker):
+
+def _markers(pem_marker: FlexiText) -> typing.Tuple[bytes, bytes]:
"""
Returns the start and end PEM markers, as bytes.
"""
@@ -31,7 +35,7 @@ def _markers(pem_marker):
b'-----END ' + pem_marker + b'-----')
-def load_pem(contents, pem_marker):
+def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes:
"""Loads a PEM file.
:param contents: the contents of the file to interpret
@@ -97,7 +101,7 @@ def load_pem(contents, pem_marker):
return base64.standard_b64decode(pem)
-def save_pem(contents, pem_marker):
+def save_pem(contents: bytes, pem_marker: FlexiText) -> bytes:
"""Saves a PEM file.
:param contents: the contents to encode in PEM format
diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py
index 310f22c..39ebc49 100644
--- a/rsa/pkcs1.py
+++ b/rsa/pkcs1.py
@@ -30,8 +30,9 @@ to your users.
import hashlib
import os
+import typing
-from rsa import common, transform, core
+from . import common, transform, core, key
# ASN.1 codes that describe the hash algorithm used.
HASH_ASN1 = {
@@ -65,7 +66,7 @@ class VerificationError(CryptoError):
"""Raised when verification fails."""
-def _pad_for_encryption(message, target_length):
+def _pad_for_encryption(message: bytes, target_length: int) -> bytes:
r"""Pads the message for encryption, returning the padded message.
:return: 00 02 RANDOM_DATA 00 MESSAGE
@@ -111,7 +112,7 @@ def _pad_for_encryption(message, target_length):
message])
-def _pad_for_signing(message, target_length):
+def _pad_for_signing(message: bytes, target_length: int) -> bytes:
r"""Pads the message for signing, returning the padded message.
The padding is always a repetition of FF bytes.
@@ -145,7 +146,7 @@ def _pad_for_signing(message, target_length):
message])
-def encrypt(message, pub_key):
+def encrypt(message: bytes, pub_key: key.PublicKey):
"""Encrypts the given message using PKCS#1 v1.5
:param message: the message to encrypt. Must be a byte string no longer than
@@ -177,7 +178,7 @@ def encrypt(message, pub_key):
return block
-def decrypt(crypto, priv_key):
+def decrypt(crypto: bytes, priv_key: key.PrivateKey) -> bytes:
r"""Decrypts the given message using PKCS#1 v1.5
The decryption is considered 'failed' when the resulting cleartext doesn't
@@ -246,14 +247,13 @@ def decrypt(crypto, priv_key):
return cleartext[sep_idx + 1:]
-def sign_hash(hash_value, priv_key, hash_method):
+def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes:
"""Signs a precomputed hash with the private key.
Hashes the message, then signs the hash with the given key. This is known
as a "detached signature", because the message itself isn't altered.
- :param hash_value: A precomputed hash to sign (ignores message). Should be set to
- None if needing to hash and sign message.
+ :param hash_value: A precomputed hash to sign (ignores message).
:param priv_key: the :py:class:`rsa.PrivateKey` to sign with
:param hash_method: the hash method used on the message. Use 'MD5', 'SHA-1',
'SHA-224', SHA-256', 'SHA-384' or 'SHA-512'.
@@ -280,7 +280,7 @@ def sign_hash(hash_value, priv_key, hash_method):
return block
-def sign(message, priv_key, hash_method):
+def sign(message: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes:
"""Signs the message with the private key.
Hashes the message, then signs the hash with the given key. This is known
@@ -302,7 +302,7 @@ def sign(message, priv_key, hash_method):
return sign_hash(msg_hash, priv_key, hash_method)
-def verify(message, signature, pub_key):
+def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) -> str:
"""Verifies that the signature matches the message.
The hash method is detected automatically from the signature.
@@ -337,7 +337,7 @@ def verify(message, signature, pub_key):
return method_name
-def find_signature_hash(signature, pub_key):
+def find_signature_hash(signature: bytes, pub_key: key.PublicKey) -> str:
"""Returns the hash name detected from the signature.
If you also want to verify the message, use :py:func:`rsa.verify()` instead.
@@ -356,7 +356,7 @@ def find_signature_hash(signature, pub_key):
return _find_method_hash(clearsig)
-def yield_fixedblocks(infile, blocksize):
+def yield_fixedblocks(infile: typing.BinaryIO, blocksize: int) -> typing.Iterator[bytes]:
"""Generator, yields each block of ``blocksize`` bytes in the input file.
:param infile: file to read and separate in blocks.
@@ -377,7 +377,7 @@ def yield_fixedblocks(infile, blocksize):
break
-def compute_hash(message, method_name):
+def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name: str) -> bytes:
"""Returns the message digest.
:param message: the signed message. Can be an 8-bit string or a file-like
@@ -394,18 +394,18 @@ def compute_hash(message, method_name):
method = HASH_METHODS[method_name]
hasher = method()
- if hasattr(message, 'read') and hasattr(message.read, '__call__'):
+ if isinstance(message, bytes):
+ hasher.update(message)
+ else:
+ assert hasattr(message, 'read') and hasattr(message.read, '__call__')
# read as 1K blocks
for block in yield_fixedblocks(message, 1024):
hasher.update(block)
- else:
- # hash the message object itself.
- hasher.update(message)
return hasher.digest()
-def _find_method_hash(clearsig):
+def _find_method_hash(clearsig: bytes) -> str:
"""Finds the hash method.
:param clearsig: full padded ASN1 and hash.
diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py
index 6242a71..b751399 100644
--- a/rsa/pkcs1_v2.py
+++ b/rsa/pkcs1_v2.py
@@ -27,7 +27,7 @@ from rsa import (
)
-def mgf1(seed, length, hasher='SHA-1'):
+def mgf1(seed: bytes, length: int, hasher='SHA-1') -> bytes:
"""
MGF1 is a Mask Generation Function based on a hash function.
diff --git a/rsa/prime.py b/rsa/prime.py
index a45f659..dcd60dd 100644
--- a/rsa/prime.py
+++ b/rsa/prime.py
@@ -26,7 +26,7 @@ import rsa.randnum
__all__ = ['getprime', 'are_relatively_prime']
-def gcd(p, q):
+def gcd(p: int, q: int) -> int:
"""Returns the greatest common divisor of p and q
>>> gcd(48, 180)
@@ -38,7 +38,7 @@ def gcd(p, q):
return p
-def get_primality_testing_rounds(number):
+def get_primality_testing_rounds(number: int) -> int:
"""Returns minimum number of rounds for Miller-Rabing primality testing,
based on number bitsize.
@@ -64,7 +64,7 @@ def get_primality_testing_rounds(number):
return 10
-def miller_rabin_primality_testing(n, k):
+def miller_rabin_primality_testing(n: int, k: int) -> bool:
"""Calculates whether n is composite (which is always correct) or prime
(which theoretically is incorrect with error probability 4**-k), by
applying Miller-Rabin primality testing.
@@ -117,7 +117,7 @@ def miller_rabin_primality_testing(n, k):
return True
-def is_prime(number):
+def is_prime(number: int) -> bool:
"""Returns True if the number is prime, and False otherwise.
>>> is_prime(2)
@@ -143,7 +143,7 @@ def is_prime(number):
return miller_rabin_primality_testing(number, k + 1)
-def getprime(nbits):
+def getprime(nbits: int) -> int:
"""Returns a prime number that can be stored in 'nbits' bits.
>>> p = getprime(128)
@@ -171,7 +171,7 @@ def getprime(nbits):
# Retry if not prime
-def are_relatively_prime(a, b):
+def are_relatively_prime(a: int, b: int) -> bool:
"""Returns True if a and b are relatively prime, and False if they
are not.
diff --git a/rsa/randnum.py b/rsa/randnum.py
index 1f0a4e5..e9bfc87 100644
--- a/rsa/randnum.py
+++ b/rsa/randnum.py
@@ -24,7 +24,7 @@ import struct
from rsa import common, transform
-def read_random_bits(nbits):
+def read_random_bits(nbits: int) -> bytes:
"""Reads 'nbits' random bits.
If nbits isn't a whole number of bytes, an extra byte will be appended with
@@ -45,7 +45,7 @@ def read_random_bits(nbits):
return randomdata
-def read_random_int(nbits):
+def read_random_int(nbits: int) -> int:
"""Reads a random integer of approximately nbits bits.
"""
@@ -59,7 +59,7 @@ def read_random_int(nbits):
return value
-def read_random_odd_int(nbits):
+def read_random_odd_int(nbits: int) -> int:
"""Reads a random odd integer of approximately nbits bits.
>>> read_random_odd_int(512) & 1
@@ -72,7 +72,7 @@ def read_random_odd_int(nbits):
return value | 1
-def randint(maxvalue):
+def randint(maxvalue: int) -> int:
"""Returns a random integer x with 1 <= x <= maxvalue
May take a very long time in specific situations. If maxvalue needs N bits
diff --git a/rsa/util.py b/rsa/util.py
index c44d04c..e0c7134 100644
--- a/rsa/util.py
+++ b/rsa/util.py
@@ -22,7 +22,7 @@ from optparse import OptionParser
import rsa.key
-def private_to_public():
+def private_to_public() -> None:
"""Reads a private key and outputs the corresponding public key."""
# Parse the CLI options