diff options
Diffstat (limited to 'src/cryptography/x509')
-rw-r--r-- | src/cryptography/x509/extensions.py | 60 | ||||
-rw-r--r-- | src/cryptography/x509/general_name.py | 21 |
2 files changed, 45 insertions, 36 deletions
diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py index 9f3d8f62d..2f8612277 100644 --- a/src/cryptography/x509/extensions.py +++ b/src/cryptography/x509/extensions.py @@ -990,15 +990,15 @@ class KeyUsage(ExtensionType): def __init__( self, - digital_signature, - content_commitment, - key_encipherment, - data_encipherment, - key_agreement, - key_cert_sign, - crl_sign, - encipher_only, - decipher_only, + digital_signature: bool, + content_commitment: bool, + key_encipherment: bool, + data_encipherment: bool, + key_agreement: bool, + key_cert_sign: bool, + crl_sign: bool, + encipher_only: bool, + decipher_only: bool, ): if not key_agreement and (encipher_only or decipher_only): raise ValueError( @@ -1101,7 +1101,11 @@ class KeyUsage(ExtensionType): class NameConstraints(ExtensionType): oid = ExtensionOID.NAME_CONSTRAINTS - def __init__(self, permitted_subtrees, excluded_subtrees): + def __init__( + self, + permitted_subtrees: typing.Optional[typing.Iterable[GeneralName]], + excluded_subtrees: typing.Optional[typing.Iterable[GeneralName]], + ): if permitted_subtrees is not None: permitted_subtrees = list(permitted_subtrees) if not all(isinstance(x, GeneralName) for x in permitted_subtrees): @@ -1180,7 +1184,9 @@ class NameConstraints(ExtensionType): class Extension(object): - def __init__(self, oid, critical, value): + def __init__( + self, oid: ObjectIdentifier, critical: bool, value: ExtensionType + ): if not isinstance(oid, ObjectIdentifier): raise TypeError( "oid argument must be an ObjectIdentifier instance." @@ -1221,7 +1227,7 @@ class Extension(object): class GeneralNames(object): - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): general_names = list(general_names) if not all(isinstance(x, GeneralName) for x in general_names): raise TypeError( @@ -1233,7 +1239,7 @@ class GeneralNames(object): __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") - def get_values_for_type(self, type): + def get_values_for_type(self, type: typing.Type[GeneralName]): # Return the value of each GeneralName, except for OtherName instances # which we return directly because it has two important properties not # just one value. @@ -1261,7 +1267,7 @@ class GeneralNames(object): class SubjectAlternativeName(ExtensionType): oid = ExtensionOID.SUBJECT_ALTERNATIVE_NAME - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1288,7 +1294,7 @@ class SubjectAlternativeName(ExtensionType): class IssuerAlternativeName(ExtensionType): oid = ExtensionOID.ISSUER_ALTERNATIVE_NAME - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1315,7 +1321,7 @@ class IssuerAlternativeName(ExtensionType): class CertificateIssuer(ExtensionType): oid = CRLEntryExtensionOID.CERTIFICATE_ISSUER - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1342,7 +1348,7 @@ class CertificateIssuer(ExtensionType): class CRLReason(ExtensionType): oid = CRLEntryExtensionOID.CRL_REASON - def __init__(self, reason): + def __init__(self, reason: ReasonFlags): if not isinstance(reason, ReasonFlags): raise TypeError("reason must be an element from ReasonFlags") @@ -1369,7 +1375,7 @@ class CRLReason(ExtensionType): class InvalidityDate(ExtensionType): oid = CRLEntryExtensionOID.INVALIDITY_DATE - def __init__(self, invalidity_date): + def __init__(self, invalidity_date: datetime.datetime): if not isinstance(invalidity_date, datetime.datetime): raise TypeError("invalidity_date must be a datetime.datetime") @@ -1398,7 +1404,12 @@ class InvalidityDate(ExtensionType): class PrecertificateSignedCertificateTimestamps(ExtensionType): oid = ExtensionOID.PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS - def __init__(self, signed_certificate_timestamps): + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ): signed_certificate_timestamps = list(signed_certificate_timestamps) if not all( isinstance(sct, SignedCertificateTimestamp) @@ -1438,7 +1449,12 @@ class PrecertificateSignedCertificateTimestamps(ExtensionType): class SignedCertificateTimestamps(ExtensionType): oid = ExtensionOID.SIGNED_CERTIFICATE_TIMESTAMPS - def __init__(self, signed_certificate_timestamps): + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ): signed_certificate_timestamps = list(signed_certificate_timestamps) if not all( isinstance(sct, SignedCertificateTimestamp) @@ -1476,7 +1492,7 @@ class SignedCertificateTimestamps(ExtensionType): class OCSPNonce(ExtensionType): oid = OCSPExtensionOID.NONCE - def __init__(self, nonce): + def __init__(self, nonce: bytes): if not isinstance(nonce, bytes): raise TypeError("nonce must be bytes") @@ -1642,7 +1658,7 @@ class IssuingDistributionPoint(ExtensionType): class UnrecognizedExtension(ExtensionType): - def __init__(self, oid, value): + def __init__(self, oid: ObjectIdentifier, value: bytes): if not isinstance(oid, ObjectIdentifier): raise TypeError("oid must be an ObjectIdentifier") self._oid = oid diff --git a/src/cryptography/x509/general_name.py b/src/cryptography/x509/general_name.py index 6683e9313..a83471e93 100644 --- a/src/cryptography/x509/general_name.py +++ b/src/cryptography/x509/general_name.py @@ -40,8 +40,7 @@ class GeneralName(metaclass=abc.ABCMeta): """ -@utils.register_interface(GeneralName) -class RFC822Name(object): +class RFC822Name(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -87,8 +86,7 @@ class RFC822Name(object): return hash(self.value) -@utils.register_interface(GeneralName) -class DNSName(object): +class DNSName(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -128,8 +126,7 @@ class DNSName(object): return hash(self.value) -@utils.register_interface(GeneralName) -class UniformResourceIdentifier(object): +class UniformResourceIdentifier(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -169,8 +166,7 @@ class UniformResourceIdentifier(object): return hash(self.value) -@utils.register_interface(GeneralName) -class DirectoryName(object): +class DirectoryName(GeneralName): def __init__(self, value: Name): if not isinstance(value, Name): raise TypeError("value must be a Name") @@ -195,8 +191,7 @@ class DirectoryName(object): return hash(self.value) -@utils.register_interface(GeneralName) -class RegisteredID(object): +class RegisteredID(GeneralName): def __init__(self, value: ObjectIdentifier): if not isinstance(value, ObjectIdentifier): raise TypeError("value must be an ObjectIdentifier") @@ -221,8 +216,7 @@ class RegisteredID(object): return hash(self.value) -@utils.register_interface(GeneralName) -class IPAddress(object): +class IPAddress(GeneralName): def __init__( self, value: typing.Union[ @@ -267,8 +261,7 @@ class IPAddress(object): return hash(self.value) -@utils.register_interface(GeneralName) -class OtherName(object): +class OtherName(GeneralName): def __init__(self, type_id: ObjectIdentifier, value: bytes): if not isinstance(type_id, ObjectIdentifier): raise TypeError("type_id must be an ObjectIdentifier") |