diff options
author | David Lord <davidism@gmail.com> | 2018-09-28 12:45:55 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-28 12:45:55 -0700 |
commit | 9a1937a45a841c2d903f866dc845f7c6b36a2931 (patch) | |
tree | fa8670f5831bdddfb143104c33b72101e589be68 | |
parent | a04a4bc3ab9d3ff9c92e620899e223377c13b96c (diff) | |
parent | 93124678292c819f0a07fa40fe4b1e0a75f7bcf4 (diff) | |
download | itsdangerous-9a1937a45a841c2d903f866dc845f7c6b36a2931.tar.gz |
Merge pull request #105 from pallets/style-checks
add style checks
-rw-r--r-- | .pre-commit-config.yaml | 14 | ||||
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | docs/conf.py | 7 | ||||
-rw-r--r-- | setup.cfg | 14 | ||||
-rwxr-xr-x | setup.py | 45 | ||||
-rw-r--r-- | src/itsdangerous/__init__.py | 300 | ||||
-rwxr-xr-x | tests/test_itsdangerous.py | 227 | ||||
-rw-r--r-- | tox.ini | 6 |
8 files changed, 358 insertions, 257 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..49eb05c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: + - repo: https://github.com/asottile/reorder_python_imports + rev: v1.2.0 + hooks: + - id: reorder-python-imports + args: ["--application-directories", "src"] + - repo: https://github.com/ambv/black + rev: 18.9b0 + hooks: + - id: black + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v1.4.0 + hooks: + - id: flake8 diff --git a/.travis.yml b/.travis.yml index c2ee1ff..92acace 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,8 @@ python: env: TOXENV=py,codecov matrix: + include: + - env: TOXENV=stylecheck,docs-html allow_failures: - python: nightly - python: pypy3 diff --git a/docs/conf.py b/docs/conf.py index 6beee90..ae04141 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,5 @@ -from pallets_sphinx_themes import ProjectLink, get_version +from pallets_sphinx_themes import get_version +from pallets_sphinx_themes import ProjectLink # Project -------------------------------------------------------------- @@ -26,9 +27,7 @@ html_context = { ProjectLink("Issue Tracker", "https://github.com/pallets/itsdangerous/issues/"), ] } -html_sidebars = { - "index": ["project.html", "localtoc.html", "versions.html"], -} +html_sidebars = {"index": ["project.html", "localtoc.html", "versions.html"]} html_static_path = ["_static"] html_title = "It's Dangerous Documentation ({})".format(version) html_show_sourcelink = False @@ -18,3 +18,17 @@ source = src/itsdangerous .tox/*/lib/python*/site-packages/itsdangerous .tox/*/site-packages/itsdangerous + +[flake8] +# B = bugbear +# E = pycodestyle errors +# F = flake8 pyflakes +# W = pycodestyle warnings +# B9 = bugbear opinions +select = B, E, F, W, B9 +# E203 = slice notation whitespace, invalid +# E501 = line length, handled by bugbear B950 +# W503 = bin op line break, invalid +ignore = E203, E501, W503 +# up to 88 allowed by bugbear B950 +max-line-length = 80 @@ -1,45 +1,46 @@ import io import re -from setuptools import setup, find_packages + +from setuptools import find_packages +from setuptools import setup with io.open("README.rst", "rt", encoding="utf8") as f: readme = f.read() with io.open("src/itsdangerous/__init__.py", "rt", encoding="utf8") as f: - version = re.search(r"__version__ = \'(.*?)\'", f.read()).group(1) + version = re.search(r"__version__ = \"(.*?)\"", f.read()).group(1) setup( - name='ItsDangerous', + name="ItsDangerous", version=version, - url='https://palletsprojects.com/p/itsdangerous/', + url="https://palletsprojects.com/p/itsdangerous/", project_urls={ "Documentation": "https://itsdangerous.palletsprojects.com/", "Code": "https://github.com/pallets/itsdangerous", "Issue tracker": "https://github.com/pallets/itsdangerous/issues", }, - license='BSD', - author='Armin Ronacher', - author_email='armin.ronacher@active-4.com', - maintainer='Pallets Team', - maintainer_email='contact@palletsprojects.com', - description='Various helpers to pass data to untrusted environments and back.', + license="BSD", + author="Armin Ronacher", + author_email="armin.ronacher@active-4.com", + maintainer="Pallets Team", + maintainer_email="contact@palletsprojects.com", + description="Various helpers to pass data to untrusted environments and back.", long_description=readme, packages=find_packages("src"), package_dir={"": "src"}, include_package_data=True, python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent' - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent" "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], ) diff --git a/src/itsdangerous/__init__.py b/src/itsdangerous/__init__.py index 63d25d2..27e0596 100644 --- a/src/itsdangerous/__init__.py +++ b/src/itsdangerous/__init__.py @@ -9,23 +9,23 @@ sources. Mainly useful for web applications. :copyright: © 2011 by the Pallets team :license: BSD, see LICENSE.rst for more details. """ - +import base64 +import hashlib +import hmac import string import struct import sys -import hmac -import zlib import time -import base64 -import hashlib +import zlib from datetime import datetime PY2 = sys.version_info[0] == 2 if PY2: from itertools import izip - text_type = unicode - number_types = (int, long, float) + + text_type = unicode # noqa: 821 + number_types = (int, long, float) # noqa: 821 else: izip = zip text_type = str @@ -37,7 +37,7 @@ try: except ImportError: import json -__version__ = '1.0.dev' +__version__ = "1.0.dev" class _CompactJSON(object): @@ -49,12 +49,12 @@ class _CompactJSON(object): @staticmethod def dumps(obj, **kwargs): - kwargs.setdefault('ensure_ascii', False) - kwargs.setdefault('separators', (',', ':')) + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("separators", (",", ":")) return json.dumps(obj, **kwargs) -def want_bytes(s, encoding='utf-8', errors='strict'): +def want_bytes(s, encoding="utf-8", errors="strict"): if isinstance(s, text_type): s = s.encode(encoding, errors) return s @@ -90,7 +90,7 @@ def constant_time_compare(val1, val2): # Starting with 2.7/3.3 the standard library has a c-implementation for # constant time string compares. -constant_time_compare = getattr(hmac, 'compare_digest', constant_time_compare) +constant_time_compare = getattr(hmac, "compare_digest", constant_time_compare) class BadData(Exception): @@ -99,6 +99,7 @@ class BadData(Exception): .. versionadded:: 0.15 """ + message = None def __init__(self, message): @@ -110,8 +111,9 @@ class BadData(Exception): if PY2: __unicode__ = __str__ + def __str__(self): - return self.__unicode__().encode('utf-8') + return self.__unicode__().encode("utf-8") class BadPayload(BadData): @@ -174,8 +176,7 @@ class BadHeader(BadSignature): .. versionadded:: 0.24 """ - def __init__(self, message, payload=None, header=None, - original_error=None): + def __init__(self, message, payload=None, header=None, original_error=None): BadSignature.__init__(self, message, payload) #: If the header is actually available but just malformed it @@ -200,7 +201,7 @@ def base64_encode(string): The resulting bytestring is safe for putting into URLs. """ string = want_bytes(string) - return base64.urlsafe_b64encode(string).rstrip(b'=') + return base64.urlsafe_b64encode(string).rstrip(b"=") def base64_decode(string): @@ -208,31 +209,29 @@ def base64_decode(string): called with a unicode string). The result is also a bytestring. """ - string = want_bytes(string, encoding='ascii', errors='ignore') - string += b'=' * (-len(string) % 4) + string = want_bytes(string, encoding="ascii", errors="ignore") + string += b"=" * (-len(string) % 4) try: return base64.urlsafe_b64decode(string) except (TypeError, ValueError): - raise BadData('Invalid base64-encoded data') + raise BadData("Invalid base64-encoded data") # The alphabet used by base64.urlsafe_* -_base64_alphabet = ( - string.ascii_letters + string.digits + '-_=' -).encode('ascii') +_base64_alphabet = (string.ascii_letters + string.digits + "-_=").encode("ascii") -_int64_struct = struct.Struct('>Q') +_int64_struct = struct.Struct(">Q") _int_to_bytes = _int64_struct.pack _bytes_to_int = _int64_struct.unpack def int_to_bytes(num): - return _int_to_bytes(num).lstrip(b'\x00') + return _int_to_bytes(num).lstrip(b"\x00") def bytes_to_int(bytestr): - return _bytes_to_int(bytestr.rjust(8, b'\x00'))[0] + return _bytes_to_int(bytestr.rjust(8, b"\x00"))[0] class SigningAlgorithm(object): @@ -255,7 +254,7 @@ class NoneAlgorithm(SigningAlgorithm): """ def get_signature(self, key, value): - return b'' + return b"" class HMACAlgorithm(SigningAlgorithm): @@ -314,18 +313,27 @@ class Signer(object): #: with an added salt. #: #: .. versionadded:: 0.14 - default_key_derivation = 'django-concat' + default_key_derivation = "django-concat" - def __init__(self, secret_key, salt=None, sep='.', key_derivation=None, - digest_method=None, algorithm=None): + def __init__( + self, + secret_key, + salt=None, + sep=".", + key_derivation=None, + digest_method=None, + algorithm=None, + ): self.secret_key = want_bytes(secret_key) self.sep = want_bytes(sep) if self.sep in _base64_alphabet: - raise ValueError('The given separator cannot be used because it ' - 'may be contained in the signature itself. ' - 'Alphanumeric characters and `-_=` must not be ' - 'used.') - self.salt = 'itsdangerous.Signer' if salt is None else salt + raise ValueError( + "The given separator cannot be used because it " + "may be contained in the signature itself. " + "Alphanumeric characters and `-_=` must not be " + "used." + ) + self.salt = "itsdangerous.Signer" if salt is None else salt if key_derivation is None: key_derivation = self.default_key_derivation self.key_derivation = key_derivation @@ -344,19 +352,18 @@ class Signer(object): password. Instead you should use large random secret keys. """ salt = want_bytes(self.salt) - if self.key_derivation == 'concat': + if self.key_derivation == "concat": return self.digest_method(salt + self.secret_key).digest() - elif self.key_derivation == 'django-concat': - return self.digest_method(salt + b'signer' + - self.secret_key).digest() - elif self.key_derivation == 'hmac': + elif self.key_derivation == "django-concat": + return self.digest_method(salt + b"signer" + self.secret_key).digest() + elif self.key_derivation == "hmac": mac = hmac.new(self.secret_key, digestmod=self.digest_method) mac.update(salt) return mac.digest() - elif self.key_derivation == 'none': + elif self.key_derivation == "none": return self.secret_key else: - raise TypeError('Unknown key derivation method') + raise TypeError("Unknown key derivation method") def get_signature(self, value): """Returns the signature for the given value""" @@ -367,9 +374,7 @@ class Signer(object): def sign(self, value): """Signs the given string.""" - return want_bytes(value) + \ - want_bytes(self.sep) + \ - self.get_signature(value) + return want_bytes(value) + want_bytes(self.sep) + self.get_signature(value) def verify_signature(self, value, sig): """Verifies the signature for the given value.""" @@ -385,12 +390,11 @@ class Signer(object): signed_value = want_bytes(signed_value) sep = want_bytes(self.sep) if sep not in signed_value: - raise BadSignature('No %r found in value' % self.sep) + raise BadSignature("No %r found in value" % self.sep) value, sig = signed_value.rsplit(sep, 1) if self.verify_signature(value, sig): return value - raise BadSignature('Signature %r does not match' % sig, - payload=value) + raise BadSignature("Signature %r does not match" % sig, payload=value) def validate(self, signed_value): """Just validates the given signed value. Returns `True` if the @@ -441,7 +445,7 @@ class TimestampSigner(Signer): sig_error = None except BadSignature as e: sig_error = e - result = e.payload or b'' + result = e.payload or b"" sep = want_bytes(self.sep) # If there is no timestamp in the result there is something @@ -449,10 +453,10 @@ class TimestampSigner(Signer): # that one directly, otherwise we have a weird situation in which # we shouldn't have come except someone uses a time-based serializer # on non-timestamp data, so catch that. - if not sep in result: + if sep not in result: if sig_error: raise sig_error - raise BadTimeSignature('timestamp missing', payload=result) + raise BadTimeSignature("timestamp missing", payload=result) value, timestamp = result.rsplit(sep, 1) try: @@ -463,22 +467,24 @@ class TimestampSigner(Signer): # Signature is *not* okay. Raise a proper error now that we have # split the value and the timestamp. if sig_error is not None: - raise BadTimeSignature(text_type(sig_error), payload=value, - date_signed=timestamp) + raise BadTimeSignature( + text_type(sig_error), payload=value, date_signed=timestamp + ) # Signature was okay but the timestamp is actually not there or # malformed. Should not happen, but well. We handle it nonetheless if timestamp is None: - raise BadTimeSignature('Malformed timestamp', payload=value) + raise BadTimeSignature("Malformed timestamp", payload=value) # Check timestamp is not older than max_age if max_age is not None: age = self.get_timestamp() - timestamp if age > max_age: raise SignatureExpired( - 'Signature age %s > %s seconds' % (age, max_age), + "Signature age %s > %s seconds" % (age, max_age), payload=value, - date_signed=self.timestamp_to_datetime(timestamp)) + date_signed=self.timestamp_to_datetime(timestamp), + ) if return_timestamp: return value, self.timestamp_to_datetime(timestamp) @@ -528,9 +534,13 @@ class Serializer(object): default_signer = Signer def __init__( - self, secret_key, salt=b'itsdangerous', - serializer=None, serializer_kwargs=None, - signer=None, signer_kwargs=None + self, + secret_key, + salt=b"itsdangerous", + serializer=None, + serializer_kwargs=None, + signer=None, + signer_kwargs=None, ): self.secret_key = want_bytes(secret_key) self.salt = want_bytes(salt) @@ -557,12 +567,14 @@ class Serializer(object): is_text = is_text_serializer(serializer) try: if is_text: - payload = payload.decode('utf-8') + payload = payload.decode("utf-8") return serializer.loads(payload) except Exception as e: - raise BadPayload('Could not load the payload because an ' - 'exception occurred on unserializing the data', - original_error=e) + raise BadPayload( + "Could not load the payload because an " + "exception occurred on unserializing the data", + original_error=e, + ) def dump_payload(self, obj): """Dumps the encoded object. The return value is always a @@ -587,7 +599,7 @@ class Serializer(object): payload = want_bytes(self.dump_payload(obj)) rv = self.make_signer(salt).sign(payload) if self.is_text_serializer: - rv = rv.decode('utf-8') + rv = rv.decode("utf-8") return rv def dump(self, obj, f, salt=None): @@ -622,8 +634,7 @@ class Serializer(object): """ return self._loads_unsafe_impl(s, salt) - def _loads_unsafe_impl(self, s, salt, load_kwargs=None, - load_payload_kwargs=None): + def _loads_unsafe_impl(self, s, salt, load_kwargs=None, load_payload_kwargs=None): """Lowlevel helper function to implement :meth:`loads_unsafe` in serializer subclasses. """ @@ -633,8 +644,10 @@ class Serializer(object): if e.payload is None: return False, None try: - return False, self.load_payload(e.payload, - **(load_payload_kwargs or {})) + return ( + False, + self.load_payload(e.payload, **(load_payload_kwargs or {})), + ) except BadPayload: return False, None @@ -661,15 +674,16 @@ class TimedSerializer(Serializer): which is a subclass of :exc:`BadSignature`. All arguments are forwarded to the signer's :meth:`~TimestampSigner.unsign` method. """ - base64d, timestamp = self.make_signer(salt) \ - .unsign(s, max_age, return_timestamp=True) + base64d, timestamp = self.make_signer(salt).unsign( + s, max_age, return_timestamp=True + ) payload = self.load_payload(base64d) if return_timestamp: return payload, timestamp return payload def loads_unsafe(self, s, max_age=None, salt=None): - load_kwargs = {'max_age': max_age} + load_kwargs = {"max_age": max_age} load_payload_kwargs = {} return self._loads_unsafe_impl(s, salt, load_kwargs, load_payload_kwargs) @@ -680,27 +694,35 @@ class JSONWebSignatureSerializer(Serializer): """ jws_algorithms = { - 'HS256': HMACAlgorithm(hashlib.sha256), - 'HS384': HMACAlgorithm(hashlib.sha384), - 'HS512': HMACAlgorithm(hashlib.sha512), - 'none': NoneAlgorithm(), + "HS256": HMACAlgorithm(hashlib.sha256), + "HS384": HMACAlgorithm(hashlib.sha384), + "HS512": HMACAlgorithm(hashlib.sha512), + "none": NoneAlgorithm(), } #: The default algorithm to use for signature generation - default_algorithm = 'HS512' + default_algorithm = "HS512" default_serializer = _CompactJSON def __init__( - self, secret_key, salt=None, - serializer=None, serializer_kwargs=None, - signer=None, signer_kwargs=None, - algorithm_name=None + self, + secret_key, + salt=None, + serializer=None, + serializer_kwargs=None, + signer=None, + signer_kwargs=None, + algorithm_name=None, ): Serializer.__init__( - self, secret_key=secret_key, salt=salt, - serializer=serializer, serializer_kwargs=serializer_kwargs, - signer=signer, signer_kwargs=signer_kwargs + self, + secret_key=secret_key, + salt=salt, + serializer=serializer, + serializer_kwargs=serializer_kwargs, + signer=signer, + signer_kwargs=signer_kwargs, ) if algorithm_name is None: algorithm_name = self.default_algorithm @@ -709,56 +731,69 @@ class JSONWebSignatureSerializer(Serializer): def load_payload(self, payload, serializer=None, return_header=False): payload = want_bytes(payload) - if b'.' not in payload: + if b"." not in payload: raise BadPayload('No "." found in value') - base64d_header, base64d_payload = payload.split(b'.', 1) + base64d_header, base64d_payload = payload.split(b".", 1) try: json_header = base64_decode(base64d_header) except Exception as e: - raise BadHeader('Could not base64 decode the header because of ' - 'an exception', original_error=e) + raise BadHeader( + "Could not base64 decode the header because of " "an exception", + original_error=e, + ) try: json_payload = base64_decode(base64d_payload) except Exception as e: - raise BadPayload('Could not base64 decode the payload because of ' - 'an exception', original_error=e) + raise BadPayload( + "Could not base64 decode the payload because of " "an exception", + original_error=e, + ) try: - header = Serializer.load_payload(self, json_header, - serializer=json) + header = Serializer.load_payload(self, json_header, serializer=json) except BadData as e: - raise BadHeader('Could not unserialize header because it was ' - 'malformed', original_error=e) + raise BadHeader( + "Could not unserialize header because it was " "malformed", + original_error=e, + ) if not isinstance(header, dict): - raise BadHeader('Header payload is not a JSON object', - header=header) + raise BadHeader("Header payload is not a JSON object", header=header) payload = Serializer.load_payload(self, json_payload, serializer=serializer) if return_header: return payload, header return payload def dump_payload(self, header, obj): - base64d_header = base64_encode(self.serializer.dumps(header, **self.serializer_kwargs)) - base64d_payload = base64_encode(self.serializer.dumps(obj, **self.serializer_kwargs)) - return base64d_header + b'.' + base64d_payload + base64d_header = base64_encode( + self.serializer.dumps(header, **self.serializer_kwargs) + ) + base64d_payload = base64_encode( + self.serializer.dumps(obj, **self.serializer_kwargs) + ) + return base64d_header + b"." + base64d_payload def make_algorithm(self, algorithm_name): try: return self.jws_algorithms[algorithm_name] except KeyError: - raise NotImplementedError('Algorithm not supported') + raise NotImplementedError("Algorithm not supported") def make_signer(self, salt=None, algorithm=None): if salt is None: salt = self.salt - key_derivation = 'none' if salt is None else None + key_derivation = "none" if salt is None else None if algorithm is None: algorithm = self.algorithm - return self.signer(self.secret_key, salt=salt, sep='.', - key_derivation=key_derivation, algorithm=algorithm) + return self.signer( + self.secret_key, + salt=salt, + sep=".", + key_derivation=key_derivation, + algorithm=algorithm, + ) def make_header(self, header_fields): header = header_fields.copy() if header_fields else {} - header['alg'] = self.algorithm_name + header["alg"] = self.algorithm_name return header def dumps(self, obj, salt=None, header_fields=None): @@ -776,16 +811,16 @@ class JSONWebSignatureSerializer(Serializer): """ payload, header = self.load_payload( self.make_signer(salt, self.algorithm).unsign(want_bytes(s)), - return_header=True) - if header.get('alg') != self.algorithm_name: - raise BadHeader('Algorithm mismatch', header=header, - payload=payload) + return_header=True, + ) + if header.get("alg") != self.algorithm_name: + raise BadHeader("Algorithm mismatch", header=header, payload=payload) if return_header: return payload, header return payload def loads_unsafe(self, s, salt=None, return_header=False): - kwargs = {'return_header': return_header} + kwargs = {"return_header": return_header} return self._loads_unsafe_impl(s, salt, kwargs, kwargs) @@ -815,37 +850,39 @@ class TimedJSONWebSignatureSerializer(JSONWebSignatureSerializer): header = JSONWebSignatureSerializer.make_header(self, header_fields) iat = self.now() exp = iat + self.expires_in - header['iat'] = iat - header['exp'] = exp + header["iat"] = iat + header["exp"] = exp return header def loads(self, s, salt=None, return_header=False): payload, header = JSONWebSignatureSerializer.loads( - self, s, salt, return_header=True) + self, s, salt, return_header=True + ) - if 'exp' not in header: - raise BadSignature('Missing expiry date', payload=payload) + if "exp" not in header: + raise BadSignature("Missing expiry date", payload=payload) try: - header['exp'] = int(header['exp']) + header["exp"] = int(header["exp"]) except ValueError: - raise BadHeader('Expiry date is not valid timestamp', payload=payload) + raise BadHeader("Expiry date is not valid timestamp", payload=payload) - if not (isinstance(header['exp'], number_types) - and header['exp'] > 0): - raise BadSignature('expiry date is not an IntDate', - payload=payload) + if not (isinstance(header["exp"], number_types) and header["exp"] > 0): + raise BadSignature("expiry date is not an IntDate", payload=payload) - if header['exp'] < self.now(): - raise SignatureExpired('Signature expired', payload=payload, - date_signed=self.get_issue_date(header)) + if header["exp"] < self.now(): + raise SignatureExpired( + "Signature expired", + payload=payload, + date_signed=self.get_issue_date(header), + ) if return_header: return payload, header return payload def get_issue_date(self, header): - rv = header.get('iat') + rv = header.get("iat") if isinstance(rv, number_types): return datetime.utcfromtimestamp(int(rv)) @@ -861,20 +898,25 @@ class URLSafeSerializerMixin(object): def load_payload(self, payload, *args, **kwargs): decompress = False - if payload.startswith(b'.'): + if payload.startswith(b"."): payload = payload[1:] decompress = True try: json = base64_decode(payload) except Exception as e: - raise BadPayload('Could not base64 decode the payload because of ' - 'an exception', original_error=e) + raise BadPayload( + "Could not base64 decode the payload because of " "an exception", + original_error=e, + ) if decompress: try: json = zlib.decompress(json) except Exception as e: - raise BadPayload('Could not zlib decompress the payload before ' - 'decoding the payload', original_error=e) + raise BadPayload( + "Could not zlib decompress the payload before " + "decoding the payload", + original_error=e, + ) return super(URLSafeSerializerMixin, self).load_payload(json, *args, **kwargs) def dump_payload(self, obj): @@ -886,7 +928,7 @@ class URLSafeSerializerMixin(object): is_compressed = True base64d = base64_encode(json) if is_compressed: - base64d = b'.' + base64d + base64d = b"." + base64d return base64d @@ -895,6 +937,7 @@ class URLSafeSerializer(URLSafeSerializerMixin, Serializer): safe string consisting of the upper and lowercase character of the alphabet as well as ``'_'``, ``'-'`` and ``'.'``. """ + default_serializer = _CompactJSON @@ -903,4 +946,5 @@ class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer): safe string consisting of the upper and lowercase character of the alphabet as well as ``'_'``, ``'-'`` and ``'.'``. """ + default_serializer = _CompactJSON diff --git a/tests/test_itsdangerous.py b/tests/test_itsdangerous.py index b28406b..be0744c 100755 --- a/tests/test_itsdangerous.py +++ b/tests/test_itsdangerous.py @@ -1,25 +1,32 @@ #!/usr/bin/env python -import time -import pickle -import pytest import hashlib +import pickle +import time import unittest from datetime import datetime +import pytest + import itsdangerous as idmod -from itsdangerous import want_bytes, text_type, PY2 +from itsdangerous import PY2 +from itsdangerous import text_type +from itsdangerous import want_bytes # Helper function for some unsafe string manipulation on encoded # data. This is required for Python 3 but would break on Python 2 if PY2: + def _coerce_string(reference_string, value): return value + + else: + def _coerce_string(reference_string, value): - assert isinstance(value, text_type), 'rhs needs to be a string' + assert isinstance(value, text_type), "rhs needs to be a string" if type(reference_string) != type(value): - value = value.encode('utf-8') + value = value.encode("utf-8") return value @@ -31,17 +38,18 @@ class UtilityTestCase(unittest.TestCase): class SignerTestCase(unittest.TestCase): signer_class = idmod.Signer + def make_signer(self, *args, **kwargs): return self.signer_class(*args, **kwargs) def test_sign(self): - s = self.make_signer('secret-key') - assert isinstance(s.sign('my string'), bytes) + s = self.make_signer("secret-key") + assert isinstance(s.sign("my string"), bytes) def test_sign_invalid_separator(self): with pytest.raises(ValueError) as excinfo: - s = self.make_signer('secret-key', sep='-') - assert 'separator cannot be used' in str(excinfo.value) + self.make_signer("secret-key", sep="-") + assert "separator cannot be used" in str(excinfo.value) class SerializerTestCase(unittest.TestCase): @@ -51,99 +59,108 @@ class SerializerTestCase(unittest.TestCase): return self.serializer_class(*args, **kwargs) def test_dumps_loads(self): - objects = (['a', 'list'], 'a string', u'a unicode string \u2019', - {'a': 'dictionary'}, 42, 42.5) - s = self.make_serializer('Test') + objects = ( + ["a", "list"], + "a string", + u"a unicode string \u2019", + {"a": "dictionary"}, + 42, + 42.5, + ) + s = self.make_serializer("Test") for o in objects: value = s.dumps(o) self.assertNotEqual(o, value) self.assertEqual(o, s.loads(value)) def test_decode_detects_tampering(self): - s = self.make_serializer('Test') + s = self.make_serializer("Test") transforms = ( lambda s: s.upper(), - lambda s: s + _coerce_string(s, 'a'), - lambda s: _coerce_string(s, 'a') + s[1:], - lambda s: s.replace(_coerce_string(s, '.'), _coerce_string(s, '')), + lambda s: s + _coerce_string(s, "a"), + lambda s: _coerce_string(s, "a") + s[1:], + lambda s: s.replace(_coerce_string(s, "."), _coerce_string(s, "")), ) - value = { - 'foo': 'bar', - 'baz': 1, - } + value = {"foo": "bar", "baz": 1} encoded = s.dumps(value) self.assertEqual(value, s.loads(encoded)) for transform in transforms: - self.assertRaises( - idmod.BadSignature, s.loads, transform(encoded)) + self.assertRaises(idmod.BadSignature, s.loads, transform(encoded)) def test_accepts_unicode(self): - objects = (['a', 'list'], 'a string', u'a unicode string \u2019', - {'a': 'dictionary'}, 42, 42.5) - s = self.make_serializer('Test') + objects = ( + ["a", "list"], + "a string", + u"a unicode string \u2019", + {"a": "dictionary"}, + 42, + 42.5, + ) + s = self.make_serializer("Test") for o in objects: value = s.dumps(o) self.assertNotEqual(o, value) self.assertEqual(o, s.loads(value)) def test_exception_attributes(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) try: - s.loads(ts + _coerce_string(ts, 'x')) + s.loads(ts + _coerce_string(ts, "x")) except idmod.BadSignature as e: - self.assertEqual(want_bytes(e.payload), - want_bytes(ts).rsplit(b'.', 1)[0]) + self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 1)[0]) self.assertEqual(s.load_payload(e.payload), value) else: - self.fail('Did not get bad signature') + self.fail("Did not get bad signature") def test_unsafe_load(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) - self.assertEqual(s.loads_unsafe(ts), (True, u'hello')) - self.assertEqual(s.loads_unsafe(ts, salt='modified'), (False, u'hello')) + self.assertEqual(s.loads_unsafe(ts), (True, u"hello")) + self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello")) def test_load_unsafe_with_unicode_strings(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) - self.assertEqual(s.loads_unsafe(ts), (True, u'hello')) - self.assertEqual(s.loads_unsafe(ts, salt='modified'), (False, u'hello')) + self.assertEqual(s.loads_unsafe(ts), (True, u"hello")) + self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello")) try: - s.loads(ts, salt='modified') + s.loads(ts, salt="modified") except idmod.BadSignature as e: - self.assertEqual(s.load_payload(e.payload), u'hello') + self.assertEqual(s.load_payload(e.payload), u"hello") def test_signer_kwargs(self): - secret_key = 'predictable-key' - value = 'hello' - s = self.make_serializer(secret_key, signer_kwargs=dict( - digest_method=hashlib.md5, - key_derivation='hmac' - )) + secret_key = "predictable-key" + value = "hello" + s = self.make_serializer( + secret_key, + signer_kwargs=dict(digest_method=hashlib.md5, key_derivation="hmac"), + ) ts = s.dumps(value) - self.assertEqual(s.loads(ts), u'hello') + self.assertEqual(s.loads(ts), u"hello") def test_serializer_kwargs(self): - s = self.make_serializer('predictable-key', serializer_kwargs={'sort_keys': True}) + s = self.make_serializer( + "predictable-key", serializer_kwargs={"sort_keys": True} + ) # pickle tests pop serializer kwargs, so skip this test for those if not s.serializer_kwargs: return - ts1 = s.dumps({'c': 3, 'a': 1, 'b': 2}) + ts1 = s.dumps({"c": 3, "a": 1, "b": 2}) ts2 = s.dumps(dict(a=1, b=2, c=3)) self.assertEqual(ts1, ts2) @@ -160,8 +177,8 @@ class TimedSerializerTestCase(SerializerTestCase): time.time = self._time def test_decode_with_timeout(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) @@ -171,12 +188,11 @@ class TimedSerializerTestCase(SerializerTestCase): time.time = lambda: 10 self.assertEqual(s.loads(ts, max_age=11), value) self.assertEqual(s.loads(ts, max_age=10), value) - self.assertRaises( - idmod.SignatureExpired, s.loads, ts, max_age=9) + self.assertRaises(idmod.SignatureExpired, s.loads, ts, max_age=9) def test_decode_return_timestamp(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) @@ -185,34 +201,32 @@ class TimedSerializerTestCase(SerializerTestCase): self.assertEqual(timestamp, datetime.utcfromtimestamp(time.time())) def test_exception_attributes(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) try: s.loads(ts, max_age=-1) except idmod.SignatureExpired as e: - self.assertEqual(e.date_signed, - datetime.utcfromtimestamp(time.time())) - self.assertEqual(want_bytes(e.payload), - want_bytes(ts).rsplit(b'.', 2)[0]) + self.assertEqual(e.date_signed, datetime.utcfromtimestamp(time.time())) + self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 2)[0]) self.assertEqual(s.load_payload(e.payload), value) else: - self.fail('Did not get expiration') + self.fail("Did not get expiration") class JSONWebSignatureSerializerTestCase(SerializerTestCase): serializer_class = idmod.JSONWebSignatureSerializer def test_decode_return_header(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" header = {"typ": "dummy"} s = self.make_serializer(secret_key) full_header = header.copy() - full_header['alg'] = s.algorithm_name + full_header["alg"] = s.algorithm_name ts = s.dumps(value, header_fields=header) loaded, loaded_header = s.loads(ts, return_header=True) @@ -220,82 +234,88 @@ class JSONWebSignatureSerializerTestCase(SerializerTestCase): self.assertEqual(loaded_header, full_header) def test_hmac_algorithms(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" - algorithms = ('HS256', 'HS384', 'HS512') + algorithms = ("HS256", "HS384", "HS512") for algorithm in algorithms: s = self.make_serializer(secret_key, algorithm_name=algorithm) ts = s.dumps(value) self.assertEqual(s.loads(ts), value) def test_none_algorithm(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" s = self.make_serializer(secret_key) ts = s.dumps(value) self.assertEqual(s.loads(ts), value) def test_algorithm_mismatch(self): - secret_key = 'predictable-key' - value = u'hello' + secret_key = "predictable-key" + value = u"hello" - s = self.make_serializer(secret_key, algorithm_name='HS256') + s = self.make_serializer(secret_key, algorithm_name="HS256") ts = s.dumps(value) - s = self.make_serializer(secret_key, algorithm_name='HS384') + s = self.make_serializer(secret_key, algorithm_name="HS384") try: s.loads(ts) except idmod.BadSignature as e: self.assertEqual(s.load_payload(e.payload), value) else: - self.fail('Did not get algorithm mismatch') + self.fail("Did not get algorithm mismatch") class TimedJSONWebSignatureSerializerTest(unittest.TestCase): serializer_class = idmod.TimedJSONWebSignatureSerializer def test_token_contains_issue_date_and_expiry_time(self): - s = self.serializer_class('secret') - result = s.dumps({'es': 'geht'}) - self.assertTrue('exp' in s.loads(result, return_header=True)[1]) - self.assertTrue('iat' in s.loads(result, return_header=True)[1]) + s = self.serializer_class("secret") + result = s.dumps({"es": "geht"}) + self.assertTrue("exp" in s.loads(result, return_header=True)[1]) + self.assertTrue("iat" in s.loads(result, return_header=True)[1]) def test_token_expires_at_given_expiry_time(self): - s = self.serializer_class('secret') + s = self.serializer_class("secret") an_hour_ago = int(time.time()) - 3601 s.now = lambda: an_hour_ago - result = s.dumps({'foo': 'bar'}) - s = self.serializer_class('secret') + result = s.dumps({"foo": "bar"}) + s = self.serializer_class("secret") self.assertRaises(idmod.SignatureExpired, s.loads, result) def test_token_is_invalid_if_expiry_time_is_missing(self): - bad_s = idmod.JSONWebSignatureSerializer('secret') + bad_s = idmod.JSONWebSignatureSerializer("secret") invalid_token_empty = bad_s.dumps({}) - s = self.serializer_class('secret') + s = self.serializer_class("secret") self.assertRaises(idmod.BadSignature, s.loads, invalid_token_empty) def test_token_is_invalid_if_expiry_time_is_negative(self): - s = self.serializer_class('secret', expires_in=-123) - result = s.dumps({'foo': 'bar'}) + s = self.serializer_class("secret", expires_in=-123) + result = s.dumps({"foo": "bar"}) self.assertRaises(idmod.BadSignature, s.loads, result) def test_creating_a_token_adds_the_expiry_date(self): expires_in_two_hours = 7200 - s = self.serializer_class('secret', expires_in=expires_in_two_hours) - result, header = s.loads(s.dumps({'foo': 'bar'}), return_header=True) - self.assertEqual(header['exp'] - header['iat'], expires_in_two_hours) + s = self.serializer_class("secret", expires_in=expires_in_two_hours) + result, header = s.loads(s.dumps({"foo": "bar"}), return_header=True) + self.assertEqual(header["exp"] - header["iat"], expires_in_two_hours) class URLSafeSerializerMixin(object): - def test_is_base62(self): - allowed = frozenset(b'0123456789abcdefghijklmnopqrstuvwxyz' + - b'ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.') - objects = (['a', 'list'], 'a string', u'a unicode string \u2019', - {'a': 'dictionary'}, 42, 42.5) - s = self.make_serializer('Test') + allowed = frozenset( + b"0123456789abcdefghijklmnopqrstuvwxyz" + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ_-." + ) + objects = ( + ["a", "list"], + "a string", + u"a unicode string \u2019", + {"a": "dictionary"}, + 42, + 42.5, + ) + s = self.make_serializer("Test") for o in objects: value = want_bytes(s.dumps(o)) self.assertTrue(set(value).issubset(set(allowed))) @@ -303,15 +323,14 @@ class URLSafeSerializerMixin(object): self.assertEqual(o, s.loads(value)) def test_invalid_base64_does_not_fail_load_payload(self): - s = idmod.URLSafeSerializer('aha!') - self.assertRaises(idmod.BadPayload, s.load_payload, b'kZ4m3du844lIN') + s = idmod.URLSafeSerializer("aha!") + self.assertRaises(idmod.BadPayload, s.load_payload, b"kZ4m3du844lIN") class PickleSerializerMixin(object): - def make_serializer(self, *args, **kwargs): - kwargs.pop('serializer_kwargs', '') - kwargs.setdefault('serializer', pickle) + kwargs.pop("serializer_kwargs", "") + kwargs.setdefault("serializer", pickle) return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs) @@ -335,5 +354,7 @@ class PickleURLSafeSerializerTestCase(PickleSerializerMixin, URLSafeSerializerTe pass -class PickleURLSafeTimedSerializerTestCase(PickleSerializerMixin, URLSafeTimedSerializerTestCase): +class PickleURLSafeTimedSerializerTestCase( + PickleSerializerMixin, URLSafeTimedSerializerTestCase +): pass @@ -1,6 +1,7 @@ [tox] envlist = py{37,36,35,34,27,py3,py} + stylecheck docs-html coverage-report skip_missing_interpreters = true @@ -11,6 +12,11 @@ setenv = deps = pytest-cov commands = pytest --cov --cov-report= {posargs} +[testenv:stylecheck] +deps = pre-commit +skip_install = True +commands = pre-commit run --all-files --show-diff-on-failure + [testenv:docs-html] deps = -r docs/requirements.txt commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs {envtmpdir}/html |