From b6cebd53fcafd3088fc8361f6d3466166f75410b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Sun, 4 Aug 2019 16:41:01 +0200 Subject: Added type annotations + some fixes to get them correct One functional change: `CryptoOperation.read_infile()` now reads bytes from `sys.stdin` instead of text. This is necessary to be consistent with the rest of the code, which all deals with bytes. --- rsa/_compat.py | 7 ++--- rsa/cli.py | 44 +++++++++++++++----------- rsa/common.py | 19 ++++++------ rsa/core.py | 6 ++-- rsa/key.py | 95 +++++++++++++++++++++++++++++---------------------------- rsa/parallel.py | 4 +-- rsa/pem.py | 10 ++++-- rsa/pkcs1.py | 36 +++++++++++----------- rsa/pkcs1_v2.py | 2 +- rsa/prime.py | 12 ++++---- rsa/randnum.py | 8 ++--- rsa/util.py | 2 +- 12 files changed, 129 insertions(+), 116 deletions(-) diff --git a/rsa/_compat.py b/rsa/_compat.py index 843583c..b31331e 100644 --- a/rsa/_compat.py +++ b/rsa/_compat.py @@ -21,14 +21,11 @@ import sys from struct import pack -def byte(num):## XXX +def byte(num: int): """ Converts a number between 0 and 255 (both inclusive) to a base-256 (byte) representation. - Use it as a replacement for ``chr`` where you are expecting a byte - because this will work on all current versions of Python:: - :param num: An unsigned integer between 0 and 255 (both inclusive). :returns: @@ -37,7 +34,7 @@ def byte(num):## XXX return pack("B", num) -def xor_bytes(b1, b2): +def xor_bytes(b1: bytes, b2: bytes) -> bytes: """ Returns the bitwise XOR result between two bytes objects, b1 ^ b2. diff --git a/rsa/cli.py b/rsa/cli.py index cbf3f97..60bc07c 100644 --- a/rsa/cli.py +++ b/rsa/cli.py @@ -22,20 +22,21 @@ These scripts are called by the executables defined in setup.py. import abc import sys import typing -from optparse import OptionParser +import optparse import rsa import rsa.key import rsa.pkcs1 HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys()) +Indexable = typing.Union[typing.Tuple, typing.List[str]] -def keygen(): +def keygen() -> None: """Key generator.""" # Parse the CLI options - parser = OptionParser(usage='usage: %prog [options] keysize', + parser = optparse.OptionParser(usage='usage: %prog [options] keysize', description='Generates a new RSA keypair of "keysize" bits.') parser.add_option('--pubout', type='string', @@ -104,13 +105,14 @@ class CryptoOperation(metaclass=abc.ABCMeta): key_class = rsa.PublicKey # type: typing.Type[rsa.key.AbstractKey] - def __init__(self): + def __init__(self) -> None: self.usage = self.usage % self.__class__.__dict__ self.input_help = self.input_help % self.__class__.__dict__ self.output_help = self.output_help % self.__class__.__dict__ @abc.abstractmethod - def perform_operation(self, indata, key, cli_args): + def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey, + cli_args: Indexable): """Performs the program's operation. Implement in a subclass. @@ -118,7 +120,7 @@ class CryptoOperation(metaclass=abc.ABCMeta): :returns: the data to write to the output. """ - def __call__(self): + def __call__(self) -> None: """Runs the program.""" (cli, cli_args) = self.parse_cli() @@ -133,13 +135,13 @@ class CryptoOperation(metaclass=abc.ABCMeta): if self.has_output: self.write_outfile(outdata, cli.output) - def parse_cli(self): + def parse_cli(self) -> typing.Tuple[optparse.Values, typing.List[str]]: """Parse the CLI options :returns: (cli_opts, cli_args) """ - parser = OptionParser(usage=self.usage, description=self.description) + parser = optparse.OptionParser(usage=self.usage, description=self.description) parser.add_option('-i', '--input', type='string', help=self.input_help) @@ -158,7 +160,7 @@ class CryptoOperation(metaclass=abc.ABCMeta): return cli, cli_args - def read_key(self, filename, keyform): + def read_key(self, filename: str, keyform: str) -> rsa.key.AbstractKey: """Reads a public or private key.""" print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr) @@ -167,7 +169,7 @@ class CryptoOperation(metaclass=abc.ABCMeta): return self.key_class.load_pkcs1(keydata, keyform) - def read_infile(self, inname): + def read_infile(self, inname: str) -> bytes: """Read the input file""" if inname: @@ -176,9 +178,9 @@ class CryptoOperation(metaclass=abc.ABCMeta): return infile.read() print('Reading input from stdin', file=sys.stderr) - return sys.stdin.read() + return sys.stdin.buffer.read() - def write_outfile(self, outdata, outname): + def write_outfile(self, outdata: bytes, outname: str) -> None: """Write the output file""" if outname: @@ -200,9 +202,10 @@ class EncryptOperation(CryptoOperation): operation_past = 'encrypted' operation_progressive = 'encrypting' - def perform_operation(self, indata, pub_key, cli_args=None): + def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey, + cli_args: Indexable=()): """Encrypts files.""" - + assert isinstance(pub_key, rsa.key.PublicKey) return rsa.encrypt(indata, pub_key) @@ -217,9 +220,10 @@ class DecryptOperation(CryptoOperation): operation_progressive = 'decrypting' key_class = rsa.PrivateKey - def perform_operation(self, indata, priv_key, cli_args=None): + def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey, + cli_args: Indexable=()): """Decrypts files.""" - + assert isinstance(priv_key, rsa.key.PrivateKey) return rsa.decrypt(indata, priv_key) @@ -239,8 +243,10 @@ class SignOperation(CryptoOperation): output_help = ('Name of the file to write the signature to. Written ' 'to stdout if this option is not present.') - def perform_operation(self, indata, priv_key, cli_args): + def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey, + cli_args: Indexable): """Signs files.""" + assert isinstance(priv_key, rsa.key.PrivateKey) hash_method = cli_args[1] if hash_method not in HASH_METHODS: @@ -264,8 +270,10 @@ class VerifyOperation(CryptoOperation): expected_cli_args = 2 has_output = False - def perform_operation(self, indata, pub_key, cli_args): + def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey, + cli_args: Indexable): """Verifies files.""" + assert isinstance(pub_key, rsa.key.PublicKey) signature_file = cli_args[1] diff --git a/rsa/common.py b/rsa/common.py index a4337f6..b983b98 100644 --- a/rsa/common.py +++ b/rsa/common.py @@ -16,17 +16,18 @@ """Common functionality shared by several modules.""" +import typing + class NotRelativePrimeError(ValueError): - def __init__(self, a, b, d, msg=None): - super(NotRelativePrimeError, self).__init__( - msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d)) + def __init__(self, a, b, d, msg=''): + super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d)) self.a = a self.b = b self.d = d -def bit_size(num): +def bit_size(num: int) -> int: """ Number of bits needed to represent a integer excluding any prefix 0 bits. @@ -54,7 +55,7 @@ def bit_size(num): raise TypeError('bit_size(num) only supports integers, not %r' % type(num)) -def byte_size(number): +def byte_size(number: int) -> int: """ Returns the number of bytes required to hold a specific long number. @@ -79,7 +80,7 @@ def byte_size(number): return ceil_div(bit_size(number), 8) -def ceil_div(num, div): +def ceil_div(num: int, div: int) -> int: """ Returns the ceiling function of a division between `num` and `div`. @@ -103,7 +104,7 @@ def ceil_div(num, div): return quanta -def extended_gcd(a, b): +def extended_gcd(a: int, b: int) -> typing.Tuple[int, int, int]: """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb """ # r = gcd(a,b) i = multiplicitive inverse of a mod b @@ -128,7 +129,7 @@ def extended_gcd(a, b): return a, lx, ly # Return only positive values -def inverse(x, n): +def inverse(x: int, n: int) -> int: """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n) >>> inverse(7, 4) @@ -145,7 +146,7 @@ def inverse(x, n): return inv -def crt(a_values, modulo_values): +def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]) -> int: """Chinese Remainder Theorem. Calculates x such that x = a[i] (mod m[i]) for each i. diff --git a/rsa/core.py b/rsa/core.py index 0660881..42f7bac 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -21,14 +21,14 @@ mathematically on integers. """ -def assert_int(var, name): +def assert_int(var: int, name: str): if isinstance(var, int): return raise TypeError('%s should be an integer, not %s' % (name, var.__class__)) -def encrypt_int(message, ekey, n): +def encrypt_int(message: int, ekey: int, n: int) -> int: """Encrypts a message using encryption key 'ekey', working modulo n""" assert_int(message, 'message') @@ -44,7 +44,7 @@ def encrypt_int(message, ekey, n): return pow(message, ekey, n) -def decrypt_int(cyphertext, dkey, n): +def decrypt_int(cyphertext: int, dkey: int, n: int) -> int: """Decrypts a cypher text using the decryption key 'dkey', working modulo n""" assert_int(cyphertext, 'cyphertext') diff --git a/rsa/key.py b/rsa/key.py index 1565967..05c77ef 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -34,6 +34,7 @@ of pyasn1. """ import logging +import typing import warnings import rsa.prime @@ -47,17 +48,17 @@ log = logging.getLogger(__name__) DEFAULT_EXPONENT = 65537 -class AbstractKey(object): +class AbstractKey: """Abstract superclass for private and public keys.""" __slots__ = ('n', 'e') - def __init__(self, n, e): + def __init__(self, n: int, e: int) -> None: self.n = n self.e = e @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a PEM-encoded file that contains @@ -69,7 +70,7 @@ class AbstractKey(object): """ @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. :param keyfile: contents of a DER-encoded file that contains @@ -80,14 +81,14 @@ class AbstractKey(object): :rtype: AbstractKey """ - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves the key in PKCS#1 PEM format, implement in a subclass. :returns: the PEM-encoded key. :rtype: bytes """ - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the key in PKCS#1 DER format, implement in a subclass. :returns: the DER-encoded key. @@ -95,7 +96,7 @@ class AbstractKey(object): """ @classmethod - def load_pkcs1(cls, keyfile, format='PEM'): + def load_pkcs1(cls, keyfile: bytes, format='PEM') -> 'AbstractKey': """Loads a key in PKCS#1 DER or PEM format. :param keyfile: contents of a DER- or PEM-encoded file that contains @@ -117,7 +118,7 @@ class AbstractKey(object): return method(keyfile) @staticmethod - def _assert_format_exists(file_format, methods): + def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing.Callable]) -> typing.Callable: """Checks whether the given file format exists in 'methods'. """ @@ -128,7 +129,7 @@ class AbstractKey(object): raise ValueError('Unsupported format: %r, try one of %s' % (file_format, formats)) - def save_pkcs1(self, format='PEM'): + def save_pkcs1(self, format='PEM') -> bytes: """Saves the key in PKCS#1 DER or PEM format. :param format: the format to save; 'PEM' or 'DER' @@ -145,7 +146,7 @@ class AbstractKey(object): method = self._assert_format_exists(format, methods) return method() - def blind(self, message, r): + def blind(self, message: int, r: int) -> int: """Performs blinding on the message using random number 'r'. :param message: the message, as integer, to blind. @@ -162,7 +163,7 @@ class AbstractKey(object): return (message * pow(r, self.e, self.n)) % self.n - def unblind(self, blinded, r): + def unblind(self, blinded: int, r: int) -> int: """Performs blinding on the message using random number 'r'. :param blinded: the blinded message, as integer, to unblind. @@ -206,18 +207,18 @@ class PublicKey(AbstractKey): def __getitem__(self, key): return getattr(self, key) - def __repr__(self): + def __repr__(self) -> str: return 'PublicKey(%i, %i)' % (self.n, self.e) - def __getstate__(self): + def __getstate__(self) -> typing.Tuple[int, int]: """Returns the key as tuple for pickling.""" return self.n, self.e - def __setstate__(self, state): + def __setstate__(self, state: typing.Tuple[int, int]) -> None: """Sets the key from tuple.""" self.n, self.e = state - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: if other is None: return False @@ -226,14 +227,14 @@ class PublicKey(AbstractKey): return self.n == other.n and self.e == other.e - def __ne__(self, other): + def __ne__(self, other: typing.Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.n, self.e)) @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'PublicKey': """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the public @@ -259,7 +260,7 @@ class PublicKey(AbstractKey): (priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey()) return cls(n=int(priv['modulus']), e=int(priv['publicExponent'])) - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the public key in PKCS#1 DER format. :returns: the DER-encoded public key. @@ -277,7 +278,7 @@ class PublicKey(AbstractKey): return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1 PEM-encoded public key file. The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and @@ -291,7 +292,7 @@ class PublicKey(AbstractKey): der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY') return cls._load_pkcs1_der(der) - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves a PKCS#1 PEM-encoded public key file. :return: contents of a PEM-encoded file that contains the public key. @@ -302,7 +303,7 @@ class PublicKey(AbstractKey): return rsa.pem.save_pem(der, 'RSA PUBLIC KEY') @classmethod - def load_pkcs1_openssl_pem(cls, keyfile): + def load_pkcs1_openssl_pem(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL. These files can be recognised in that they start with BEGIN PUBLIC KEY @@ -321,14 +322,12 @@ class PublicKey(AbstractKey): return cls.load_pkcs1_openssl_der(der) @classmethod - def load_pkcs1_openssl_der(cls, keyfile): + def load_pkcs1_openssl_der(cls, keyfile: bytes) -> 'PublicKey': """Loads a PKCS#1 DER-encoded public key file from OpenSSL. :param keyfile: contents of a DER-encoded file that contains the public key, from OpenSSL. :return: a PublicKey object - :rtype: bytes - """ from rsa.asn1 import OpenSSLPubKey @@ -369,7 +368,7 @@ class PrivateKey(AbstractKey): __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef') - def __init__(self, n, e, d, p, q): + def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: AbstractKey.__init__(self, n, e) self.d = d self.p = p @@ -383,18 +382,18 @@ class PrivateKey(AbstractKey): def __getitem__(self, key): return getattr(self, key) - def __repr__(self): - return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self + def __repr__(self) -> str: + return 'PrivateKey(%i, %i, %i, %i, %i)' % (self.n, self.e, self.d, self.p, self.q) - def __getstate__(self): + def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]: """Returns the key as tuple for pickling.""" return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef - def __setstate__(self, state): + def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]): """Sets the key from tuple.""" self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: if other is None: return False @@ -410,13 +409,13 @@ class PrivateKey(AbstractKey): self.exp2 == other.exp2 and self.coef == other.coef) - def __ne__(self, other): + def __ne__(self, other: typing.Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) - def blinded_decrypt(self, encrypted): + def blinded_decrypt(self, encrypted: int) -> int: """Decrypts the message using blinding to prevent side-channel attacks. :param encrypted: the encrypted message @@ -432,7 +431,7 @@ class PrivateKey(AbstractKey): return self.unblind(decrypted, blind_r) - def blinded_encrypt(self, message): + def blinded_encrypt(self, message: int) -> int: """Encrypts the message using blinding to prevent side-channel attacks. :param message: the message to encrypt @@ -448,7 +447,7 @@ class PrivateKey(AbstractKey): return self.unblind(encrypted, blind_r) @classmethod - def _load_pkcs1_der(cls, keyfile): + def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey': """Loads a key in PKCS#1 DER format. :param keyfile: contents of a DER-encoded file that contains the private @@ -505,7 +504,7 @@ class PrivateKey(AbstractKey): return key - def _save_pkcs1_der(self): + def _save_pkcs1_der(self) -> bytes: """Saves the private key in PKCS#1 DER format. :returns: the DER-encoded private key. @@ -543,7 +542,7 @@ class PrivateKey(AbstractKey): return encoder.encode(asn_key) @classmethod - def _load_pkcs1_pem(cls, keyfile): + def _load_pkcs1_pem(cls, keyfile: bytes) -> 'PrivateKey': """Loads a PKCS#1 PEM-encoded private key file. The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and @@ -558,7 +557,7 @@ class PrivateKey(AbstractKey): der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY') return cls._load_pkcs1_der(der) - def _save_pkcs1_pem(self): + def _save_pkcs1_pem(self) -> bytes: """Saves a PKCS#1 PEM-encoded private key file. :return: contents of a PEM-encoded file that contains the private key. @@ -569,7 +568,7 @@ class PrivateKey(AbstractKey): return rsa.pem.save_pem(der, b'RSA PRIVATE KEY') -def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): +def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typing.Tuple[int, int]: """Returns a tuple of two different primes of nbits bits each. The resulting p * q has exacty 2 * nbits bits, and the returned p and q @@ -647,7 +646,7 @@ def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): return max(p, q), min(p, q) -def calculate_keys_custom_exponent(p, q, exponent): +def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]: """Calculates an encryption and a decryption key given p, q and an exponent, and returns them as a tuple (e, d) @@ -677,7 +676,7 @@ def calculate_keys_custom_exponent(p, q, exponent): return exponent, d -def calculate_keys(p, q): +def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]: """Calculates an encryption and a decryption key given p and q, and returns them as a tuple (e, d) @@ -690,7 +689,10 @@ def calculate_keys(p, q): return calculate_keys_custom_exponent(p, q, DEFAULT_EXPONENT) -def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT): +def gen_keys(nbits: int, + getprime_func: typing.Callable[[int], int], + accurate=True, + exponent=DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]: """Generate RSA keys of nbits bits. Returns (p, q, e, d). Note: this can take a long time, depending on the key size. @@ -718,7 +720,8 @@ def gen_keys(nbits, getprime_func, accurate=True, exponent=DEFAULT_EXPONENT): return p, q, e, d -def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT): +def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \ + -> typing.Tuple[PublicKey, PrivateKey]: """Generates public and private keys, and returns them as (pub, priv). The public key is also known as the 'encryption key', and is a @@ -753,9 +756,9 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT): # Determine which getprime function to use if poolsize > 1: from rsa import parallel - import functools - getprime_func = functools.partial(parallel.getprime, poolsize=poolsize) + def getprime_func(nbits): + return parallel.getprime(nbits, poolsize=poolsize) else: getprime_func = rsa.prime.getprime diff --git a/rsa/parallel.py b/rsa/parallel.py index ef9f07f..e81bcad 100644 --- a/rsa/parallel.py +++ b/rsa/parallel.py @@ -30,7 +30,7 @@ import rsa.prime import rsa.randnum -def _find_prime(nbits, pipe): +def _find_prime(nbits: int, pipe) -> None: while True: integer = rsa.randnum.read_random_odd_int(nbits) @@ -40,7 +40,7 @@ def _find_prime(nbits, pipe): return -def getprime(nbits, poolsize): +def getprime(nbits: int, poolsize: int) -> int: """Returns a prime number that can be stored in 'nbits' bits. Works in multiple threads at the same time. diff --git a/rsa/pem.py b/rsa/pem.py index 0650e64..02c7691 100644 --- a/rsa/pem.py +++ b/rsa/pem.py @@ -17,9 +17,13 @@ """Functions that load and write PEM-encoded files.""" import base64 +import typing +# Should either be ASCII strings or bytes. +FlexiText = typing.Union[str, bytes] -def _markers(pem_marker): + +def _markers(pem_marker: FlexiText) -> typing.Tuple[bytes, bytes]: """ Returns the start and end PEM markers, as bytes. """ @@ -31,7 +35,7 @@ def _markers(pem_marker): b'-----END ' + pem_marker + b'-----') -def load_pem(contents, pem_marker): +def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes: """Loads a PEM file. :param contents: the contents of the file to interpret @@ -97,7 +101,7 @@ def load_pem(contents, pem_marker): return base64.standard_b64decode(pem) -def save_pem(contents, pem_marker): +def save_pem(contents: bytes, pem_marker: FlexiText) -> bytes: """Saves a PEM file. :param contents: the contents to encode in PEM format diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py index 310f22c..39ebc49 100644 --- a/rsa/pkcs1.py +++ b/rsa/pkcs1.py @@ -30,8 +30,9 @@ to your users. import hashlib import os +import typing -from rsa import common, transform, core +from . import common, transform, core, key # ASN.1 codes that describe the hash algorithm used. HASH_ASN1 = { @@ -65,7 +66,7 @@ class VerificationError(CryptoError): """Raised when verification fails.""" -def _pad_for_encryption(message, target_length): +def _pad_for_encryption(message: bytes, target_length: int) -> bytes: r"""Pads the message for encryption, returning the padded message. :return: 00 02 RANDOM_DATA 00 MESSAGE @@ -111,7 +112,7 @@ def _pad_for_encryption(message, target_length): message]) -def _pad_for_signing(message, target_length): +def _pad_for_signing(message: bytes, target_length: int) -> bytes: r"""Pads the message for signing, returning the padded message. The padding is always a repetition of FF bytes. @@ -145,7 +146,7 @@ def _pad_for_signing(message, target_length): message]) -def encrypt(message, pub_key): +def encrypt(message: bytes, pub_key: key.PublicKey): """Encrypts the given message using PKCS#1 v1.5 :param message: the message to encrypt. Must be a byte string no longer than @@ -177,7 +178,7 @@ def encrypt(message, pub_key): return block -def decrypt(crypto, priv_key): +def decrypt(crypto: bytes, priv_key: key.PrivateKey) -> bytes: r"""Decrypts the given message using PKCS#1 v1.5 The decryption is considered 'failed' when the resulting cleartext doesn't @@ -246,14 +247,13 @@ def decrypt(crypto, priv_key): return cleartext[sep_idx + 1:] -def sign_hash(hash_value, priv_key, hash_method): +def sign_hash(hash_value: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes: """Signs a precomputed hash with the private key. Hashes the message, then signs the hash with the given key. This is known as a "detached signature", because the message itself isn't altered. - :param hash_value: A precomputed hash to sign (ignores message). Should be set to - None if needing to hash and sign message. + :param hash_value: A precomputed hash to sign (ignores message). :param priv_key: the :py:class:`rsa.PrivateKey` to sign with :param hash_method: the hash method used on the message. Use 'MD5', 'SHA-1', 'SHA-224', SHA-256', 'SHA-384' or 'SHA-512'. @@ -280,7 +280,7 @@ def sign_hash(hash_value, priv_key, hash_method): return block -def sign(message, priv_key, hash_method): +def sign(message: bytes, priv_key: key.PrivateKey, hash_method: str) -> bytes: """Signs the message with the private key. Hashes the message, then signs the hash with the given key. This is known @@ -302,7 +302,7 @@ def sign(message, priv_key, hash_method): return sign_hash(msg_hash, priv_key, hash_method) -def verify(message, signature, pub_key): +def verify(message: bytes, signature: bytes, pub_key: key.PublicKey) -> str: """Verifies that the signature matches the message. The hash method is detected automatically from the signature. @@ -337,7 +337,7 @@ def verify(message, signature, pub_key): return method_name -def find_signature_hash(signature, pub_key): +def find_signature_hash(signature: bytes, pub_key: key.PublicKey) -> str: """Returns the hash name detected from the signature. If you also want to verify the message, use :py:func:`rsa.verify()` instead. @@ -356,7 +356,7 @@ def find_signature_hash(signature, pub_key): return _find_method_hash(clearsig) -def yield_fixedblocks(infile, blocksize): +def yield_fixedblocks(infile: typing.BinaryIO, blocksize: int) -> typing.Iterator[bytes]: """Generator, yields each block of ``blocksize`` bytes in the input file. :param infile: file to read and separate in blocks. @@ -377,7 +377,7 @@ def yield_fixedblocks(infile, blocksize): break -def compute_hash(message, method_name): +def compute_hash(message: typing.Union[bytes, typing.BinaryIO], method_name: str) -> bytes: """Returns the message digest. :param message: the signed message. Can be an 8-bit string or a file-like @@ -394,18 +394,18 @@ def compute_hash(message, method_name): method = HASH_METHODS[method_name] hasher = method() - if hasattr(message, 'read') and hasattr(message.read, '__call__'): + if isinstance(message, bytes): + hasher.update(message) + else: + assert hasattr(message, 'read') and hasattr(message.read, '__call__') # read as 1K blocks for block in yield_fixedblocks(message, 1024): hasher.update(block) - else: - # hash the message object itself. - hasher.update(message) return hasher.digest() -def _find_method_hash(clearsig): +def _find_method_hash(clearsig: bytes) -> str: """Finds the hash method. :param clearsig: full padded ASN1 and hash. diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index 6242a71..b751399 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -27,7 +27,7 @@ from rsa import ( ) -def mgf1(seed, length, hasher='SHA-1'): +def mgf1(seed: bytes, length: int, hasher='SHA-1') -> bytes: """ MGF1 is a Mask Generation Function based on a hash function. diff --git a/rsa/prime.py b/rsa/prime.py index a45f659..dcd60dd 100644 --- a/rsa/prime.py +++ b/rsa/prime.py @@ -26,7 +26,7 @@ import rsa.randnum __all__ = ['getprime', 'are_relatively_prime'] -def gcd(p, q): +def gcd(p: int, q: int) -> int: """Returns the greatest common divisor of p and q >>> gcd(48, 180) @@ -38,7 +38,7 @@ def gcd(p, q): return p -def get_primality_testing_rounds(number): +def get_primality_testing_rounds(number: int) -> int: """Returns minimum number of rounds for Miller-Rabing primality testing, based on number bitsize. @@ -64,7 +64,7 @@ def get_primality_testing_rounds(number): return 10 -def miller_rabin_primality_testing(n, k): +def miller_rabin_primality_testing(n: int, k: int) -> bool: """Calculates whether n is composite (which is always correct) or prime (which theoretically is incorrect with error probability 4**-k), by applying Miller-Rabin primality testing. @@ -117,7 +117,7 @@ def miller_rabin_primality_testing(n, k): return True -def is_prime(number): +def is_prime(number: int) -> bool: """Returns True if the number is prime, and False otherwise. >>> is_prime(2) @@ -143,7 +143,7 @@ def is_prime(number): return miller_rabin_primality_testing(number, k + 1) -def getprime(nbits): +def getprime(nbits: int) -> int: """Returns a prime number that can be stored in 'nbits' bits. >>> p = getprime(128) @@ -171,7 +171,7 @@ def getprime(nbits): # Retry if not prime -def are_relatively_prime(a, b): +def are_relatively_prime(a: int, b: int) -> bool: """Returns True if a and b are relatively prime, and False if they are not. diff --git a/rsa/randnum.py b/rsa/randnum.py index 1f0a4e5..e9bfc87 100644 --- a/rsa/randnum.py +++ b/rsa/randnum.py @@ -24,7 +24,7 @@ import struct from rsa import common, transform -def read_random_bits(nbits): +def read_random_bits(nbits: int) -> bytes: """Reads 'nbits' random bits. If nbits isn't a whole number of bytes, an extra byte will be appended with @@ -45,7 +45,7 @@ def read_random_bits(nbits): return randomdata -def read_random_int(nbits): +def read_random_int(nbits: int) -> int: """Reads a random integer of approximately nbits bits. """ @@ -59,7 +59,7 @@ def read_random_int(nbits): return value -def read_random_odd_int(nbits): +def read_random_odd_int(nbits: int) -> int: """Reads a random odd integer of approximately nbits bits. >>> read_random_odd_int(512) & 1 @@ -72,7 +72,7 @@ def read_random_odd_int(nbits): return value | 1 -def randint(maxvalue): +def randint(maxvalue: int) -> int: """Returns a random integer x with 1 <= x <= maxvalue May take a very long time in specific situations. If maxvalue needs N bits diff --git a/rsa/util.py b/rsa/util.py index c44d04c..e0c7134 100644 --- a/rsa/util.py +++ b/rsa/util.py @@ -22,7 +22,7 @@ from optparse import OptionParser import rsa.key -def private_to_public(): +def private_to_public() -> None: """Reads a private key and outputs the corresponding public key.""" # Parse the CLI options -- cgit v1.2.1