diff options
author | Alex Gaynor <alex.gaynor@gmail.com> | 2021-02-11 13:56:46 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-11 12:56:46 -0600 |
commit | 5511445e95a16fa12b464d57ace4bb17855fe844 (patch) | |
tree | 4f9acb5371a53670664d38f64ef6cf59cbf5bbd4 | |
parent | 9efc6d46fb98a52a9bb956096a9389f21bd8de92 (diff) | |
download | cryptography-master.tar.gz |
Start typing a bunch of stuff from x509 extensions (#5812)master
-rw-r--r-- | src/cryptography/x509/extensions.py | 60 | ||||
-rw-r--r-- | src/cryptography/x509/general_name.py | 21 | ||||
-rw-r--r-- | tests/x509/test_ocsp.py | 4 | ||||
-rw-r--r-- | tests/x509/test_x509.py | 4 | ||||
-rw-r--r-- | tests/x509/test_x509_ext.py | 77 |
5 files changed, 110 insertions, 56 deletions
diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py index 9f3d8f62d..2f8612277 100644 --- a/src/cryptography/x509/extensions.py +++ b/src/cryptography/x509/extensions.py @@ -990,15 +990,15 @@ class KeyUsage(ExtensionType): def __init__( self, - digital_signature, - content_commitment, - key_encipherment, - data_encipherment, - key_agreement, - key_cert_sign, - crl_sign, - encipher_only, - decipher_only, + digital_signature: bool, + content_commitment: bool, + key_encipherment: bool, + data_encipherment: bool, + key_agreement: bool, + key_cert_sign: bool, + crl_sign: bool, + encipher_only: bool, + decipher_only: bool, ): if not key_agreement and (encipher_only or decipher_only): raise ValueError( @@ -1101,7 +1101,11 @@ class KeyUsage(ExtensionType): class NameConstraints(ExtensionType): oid = ExtensionOID.NAME_CONSTRAINTS - def __init__(self, permitted_subtrees, excluded_subtrees): + def __init__( + self, + permitted_subtrees: typing.Optional[typing.Iterable[GeneralName]], + excluded_subtrees: typing.Optional[typing.Iterable[GeneralName]], + ): if permitted_subtrees is not None: permitted_subtrees = list(permitted_subtrees) if not all(isinstance(x, GeneralName) for x in permitted_subtrees): @@ -1180,7 +1184,9 @@ class NameConstraints(ExtensionType): class Extension(object): - def __init__(self, oid, critical, value): + def __init__( + self, oid: ObjectIdentifier, critical: bool, value: ExtensionType + ): if not isinstance(oid, ObjectIdentifier): raise TypeError( "oid argument must be an ObjectIdentifier instance." @@ -1221,7 +1227,7 @@ class Extension(object): class GeneralNames(object): - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): general_names = list(general_names) if not all(isinstance(x, GeneralName) for x in general_names): raise TypeError( @@ -1233,7 +1239,7 @@ class GeneralNames(object): __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") - def get_values_for_type(self, type): + def get_values_for_type(self, type: typing.Type[GeneralName]): # Return the value of each GeneralName, except for OtherName instances # which we return directly because it has two important properties not # just one value. @@ -1261,7 +1267,7 @@ class GeneralNames(object): class SubjectAlternativeName(ExtensionType): oid = ExtensionOID.SUBJECT_ALTERNATIVE_NAME - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1288,7 +1294,7 @@ class SubjectAlternativeName(ExtensionType): class IssuerAlternativeName(ExtensionType): oid = ExtensionOID.ISSUER_ALTERNATIVE_NAME - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1315,7 +1321,7 @@ class IssuerAlternativeName(ExtensionType): class CertificateIssuer(ExtensionType): oid = CRLEntryExtensionOID.CERTIFICATE_ISSUER - def __init__(self, general_names): + def __init__(self, general_names: typing.Iterable[GeneralName]): self._general_names = GeneralNames(general_names) __len__, __iter__, __getitem__ = _make_sequence_methods("_general_names") @@ -1342,7 +1348,7 @@ class CertificateIssuer(ExtensionType): class CRLReason(ExtensionType): oid = CRLEntryExtensionOID.CRL_REASON - def __init__(self, reason): + def __init__(self, reason: ReasonFlags): if not isinstance(reason, ReasonFlags): raise TypeError("reason must be an element from ReasonFlags") @@ -1369,7 +1375,7 @@ class CRLReason(ExtensionType): class InvalidityDate(ExtensionType): oid = CRLEntryExtensionOID.INVALIDITY_DATE - def __init__(self, invalidity_date): + def __init__(self, invalidity_date: datetime.datetime): if not isinstance(invalidity_date, datetime.datetime): raise TypeError("invalidity_date must be a datetime.datetime") @@ -1398,7 +1404,12 @@ class InvalidityDate(ExtensionType): class PrecertificateSignedCertificateTimestamps(ExtensionType): oid = ExtensionOID.PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS - def __init__(self, signed_certificate_timestamps): + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ): signed_certificate_timestamps = list(signed_certificate_timestamps) if not all( isinstance(sct, SignedCertificateTimestamp) @@ -1438,7 +1449,12 @@ class PrecertificateSignedCertificateTimestamps(ExtensionType): class SignedCertificateTimestamps(ExtensionType): oid = ExtensionOID.SIGNED_CERTIFICATE_TIMESTAMPS - def __init__(self, signed_certificate_timestamps): + def __init__( + self, + signed_certificate_timestamps: typing.Iterable[ + SignedCertificateTimestamp + ], + ): signed_certificate_timestamps = list(signed_certificate_timestamps) if not all( isinstance(sct, SignedCertificateTimestamp) @@ -1476,7 +1492,7 @@ class SignedCertificateTimestamps(ExtensionType): class OCSPNonce(ExtensionType): oid = OCSPExtensionOID.NONCE - def __init__(self, nonce): + def __init__(self, nonce: bytes): if not isinstance(nonce, bytes): raise TypeError("nonce must be bytes") @@ -1642,7 +1658,7 @@ class IssuingDistributionPoint(ExtensionType): class UnrecognizedExtension(ExtensionType): - def __init__(self, oid, value): + def __init__(self, oid: ObjectIdentifier, value: bytes): if not isinstance(oid, ObjectIdentifier): raise TypeError("oid must be an ObjectIdentifier") self._oid = oid diff --git a/src/cryptography/x509/general_name.py b/src/cryptography/x509/general_name.py index 6683e9313..a83471e93 100644 --- a/src/cryptography/x509/general_name.py +++ b/src/cryptography/x509/general_name.py @@ -40,8 +40,7 @@ class GeneralName(metaclass=abc.ABCMeta): """ -@utils.register_interface(GeneralName) -class RFC822Name(object): +class RFC822Name(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -87,8 +86,7 @@ class RFC822Name(object): return hash(self.value) -@utils.register_interface(GeneralName) -class DNSName(object): +class DNSName(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -128,8 +126,7 @@ class DNSName(object): return hash(self.value) -@utils.register_interface(GeneralName) -class UniformResourceIdentifier(object): +class UniformResourceIdentifier(GeneralName): def __init__(self, value: str): if isinstance(value, str): try: @@ -169,8 +166,7 @@ class UniformResourceIdentifier(object): return hash(self.value) -@utils.register_interface(GeneralName) -class DirectoryName(object): +class DirectoryName(GeneralName): def __init__(self, value: Name): if not isinstance(value, Name): raise TypeError("value must be a Name") @@ -195,8 +191,7 @@ class DirectoryName(object): return hash(self.value) -@utils.register_interface(GeneralName) -class RegisteredID(object): +class RegisteredID(GeneralName): def __init__(self, value: ObjectIdentifier): if not isinstance(value, ObjectIdentifier): raise TypeError("value must be an ObjectIdentifier") @@ -221,8 +216,7 @@ class RegisteredID(object): return hash(self.value) -@utils.register_interface(GeneralName) -class IPAddress(object): +class IPAddress(GeneralName): def __init__( self, value: typing.Union[ @@ -267,8 +261,7 @@ class IPAddress(object): return hash(self.value) -@utils.register_interface(GeneralName) -class OtherName(object): +class OtherName(GeneralName): def __init__(self, type_id: ObjectIdentifier, value: bytes): if not isinstance(type_id, ObjectIdentifier): raise TypeError("type_id must be an ObjectIdentifier") diff --git a/tests/x509/test_ocsp.py b/tests/x509/test_ocsp.py index 5793f6d62..5d9da790a 100644 --- a/tests/x509/test_ocsp.py +++ b/tests/x509/test_ocsp.py @@ -726,7 +726,9 @@ class TestOCSPResponseBuilder(object): class TestSignedCertificateTimestampsExtension(object): def test_init(self): with pytest.raises(TypeError): - x509.SignedCertificateTimestamps([object()]) + x509.SignedCertificateTimestamps( + [object()] # type: ignore[list-item] + ) def test_repr(self): assert repr(x509.SignedCertificateTimestamps([])) == ( diff --git a/tests/x509/test_x509.py b/tests/x509/test_x509.py index 39f7bb951..b1e86f436 100644 --- a/tests/x509/test_x509.py +++ b/tests/x509/test_x509.py @@ -4070,7 +4070,9 @@ class TestCertificateSigningRequestBuilder(object): x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "SAN")]) ) .add_extension( - x509.SubjectAlternativeName([FakeGeneralName("")]), + x509.SubjectAlternativeName( + [FakeGeneralName("")] # type:ignore[list-item] + ), critical=False, ) ) diff --git a/tests/x509/test_x509_ext.py b/tests/x509/test_x509_ext.py index 011649f4e..938357f2d 100644 --- a/tests/x509/test_x509_ext.py +++ b/tests/x509/test_x509_ext.py @@ -56,12 +56,16 @@ class TestExtension(object): def test_not_an_oid(self): bc = x509.BasicConstraints(ca=False, path_length=None) with pytest.raises(TypeError): - x509.Extension("notanoid", True, bc) + x509.Extension("notanoid", True, bc) # type:ignore[arg-type] def test_critical_not_a_bool(self): bc = x509.BasicConstraints(ca=False, path_length=None) with pytest.raises(TypeError): - x509.Extension(ExtensionOID.BASIC_CONSTRAINTS, "notabool", bc) + x509.Extension( + ExtensionOID.BASIC_CONSTRAINTS, + "notabool", # type:ignore[arg-type] + bc, + ) def test_repr(self): bc = x509.BasicConstraints(ca=False, path_length=None) @@ -73,16 +77,38 @@ class TestExtension(object): ) def test_eq(self): - ext1 = x509.Extension(x509.ObjectIdentifier("1.2.3.4"), False, "value") - ext2 = x509.Extension(x509.ObjectIdentifier("1.2.3.4"), False, "value") + ext1 = x509.Extension( + x509.ObjectIdentifier("1.2.3.4"), + False, + x509.BasicConstraints(ca=False, path_length=None), + ) + ext2 = x509.Extension( + x509.ObjectIdentifier("1.2.3.4"), + False, + x509.BasicConstraints(ca=False, path_length=None), + ) assert ext1 == ext2 def test_ne(self): - ext1 = x509.Extension(x509.ObjectIdentifier("1.2.3.4"), False, "value") - ext2 = x509.Extension(x509.ObjectIdentifier("1.2.3.5"), False, "value") - ext3 = x509.Extension(x509.ObjectIdentifier("1.2.3.4"), True, "value") + ext1 = x509.Extension( + x509.ObjectIdentifier("1.2.3.4"), + False, + x509.BasicConstraints(ca=False, path_length=None), + ) + ext2 = x509.Extension( + x509.ObjectIdentifier("1.2.3.5"), + False, + x509.BasicConstraints(ca=False, path_length=None), + ) + ext3 = x509.Extension( + x509.ObjectIdentifier("1.2.3.4"), + True, + x509.BasicConstraints(ca=False, path_length=None), + ) ext4 = x509.Extension( - x509.ObjectIdentifier("1.2.3.4"), False, "value4" + x509.ObjectIdentifier("1.2.3.4"), + False, + x509.BasicConstraints(ca=True, path_length=None), ) assert ext1 != ext2 assert ext1 != ext3 @@ -181,7 +207,9 @@ class TestTLSFeature(object): class TestUnrecognizedExtension(object): def test_invalid_oid(self): with pytest.raises(TypeError): - x509.UnrecognizedExtension("notanoid", b"somedata") + x509.UnrecognizedExtension( + "notanoid", b"somedata" # type:ignore[arg-type] + ) def test_eq(self): ext1 = x509.UnrecognizedExtension( @@ -289,7 +317,7 @@ class TestCertificateIssuer(object): class TestCRLReason(object): def test_invalid_reason_flags(self): with pytest.raises(TypeError): - x509.CRLReason("notareason") + x509.CRLReason("notareason") # type:ignore[arg-type] def test_eq(self): reason1 = x509.CRLReason(x509.ReasonFlags.unspecified) @@ -346,7 +374,7 @@ class TestDeltaCRLIndicator(object): class TestInvalidityDate(object): def test_invalid_invalidity_date(self): with pytest.raises(TypeError): - x509.InvalidityDate("notadate") + x509.InvalidityDate("notadate") # type:ignore[arg-type] def test_eq(self): invalid1 = x509.InvalidityDate(datetime.datetime(2015, 1, 1, 1, 1)) @@ -1990,7 +2018,12 @@ class TestGeneralNames(object): def test_invalid_general_names(self): with pytest.raises(TypeError): - x509.GeneralNames([x509.DNSName("cryptography.io"), "invalid"]) + x509.GeneralNames( + [ + x509.DNSName("cryptography.io"), + "invalid", # type:ignore[list-item] + ] + ) def test_repr(self): gns = x509.GeneralNames([x509.DNSName("cryptography.io")]) @@ -2049,7 +2082,10 @@ class TestIssuerAlternativeName(object): def test_invalid_general_names(self): with pytest.raises(TypeError): x509.IssuerAlternativeName( - [x509.DNSName("cryptography.io"), "invalid"] + [ + x509.DNSName("cryptography.io"), + "invalid", # type:ignore[list-item] + ] ) def test_repr(self): @@ -2157,7 +2193,10 @@ class TestSubjectAlternativeName(object): def test_invalid_general_names(self): with pytest.raises(TypeError): x509.SubjectAlternativeName( - [x509.DNSName("cryptography.io"), "invalid"] + [ + x509.DNSName("cryptography.io"), + "invalid", # type:ignore[list-item] + ] ) def test_repr(self): @@ -3335,11 +3374,11 @@ class TestNameConstraints(object): def test_invalid_permitted_subtrees(self): with pytest.raises(TypeError): - x509.NameConstraints("badpermitted", None) + x509.NameConstraints("badpermitted", None) # type:ignore[arg-type] def test_invalid_excluded_subtrees(self): with pytest.raises(TypeError): - x509.NameConstraints(None, "badexcluded") + x509.NameConstraints(None, "badexcluded") # type:ignore[arg-type] def test_no_subtrees(self): with pytest.raises(ValueError): @@ -5365,7 +5404,9 @@ class TestSignedCertificateTimestamps(object): class TestPrecertificateSignedCertificateTimestampsExtension(object): def test_init(self): with pytest.raises(TypeError): - x509.PrecertificateSignedCertificateTimestamps([object()]) + x509.PrecertificateSignedCertificateTimestamps( + [object()] # type:ignore[list-item] + ) def test_repr(self): assert repr(x509.PrecertificateSignedCertificateTimestamps([])) == ( @@ -5566,7 +5607,7 @@ class TestInvalidExtension(object): class TestOCSPNonce(object): def test_non_bytes(self): with pytest.raises(TypeError): - x509.OCSPNonce(38) + x509.OCSPNonce(38) # type:ignore[arg-type] def test_eq(self): nonce1 = x509.OCSPNonce(b"0" * 5) |