summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Lord <davidism@gmail.com>2018-10-17 09:12:40 -0700
committerDavid Lord <davidism@gmail.com>2018-10-17 12:47:51 -0700
commita8c7d1b9737fda3b3092d5950a021398f06955b5 (patch)
tree727ec1dd7a76e78fb35db3c162c10d9f1945e01b
parentccc67a137fb912eb65c6959107d0d5bc69db3890 (diff)
downloaditsdangerous-split.tar.gz
split tests into modules, use pytest, 100% coveragesplit
-rw-r--r--.gitignore1
-rw-r--r--.pre-commit-config.yaml9
-rw-r--r--.travis.yml4
-rw-r--r--LICENSE.rst2
-rw-r--r--setup.cfg2
-rw-r--r--src/itsdangerous/_compat.py4
-rw-r--r--src/itsdangerous/jws.py8
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/test_compat.py11
-rw-r--r--tests/test_encoding.py38
-rwxr-xr-xtests/test_itsdangerous.py359
-rw-r--r--tests/test_jws.py122
-rw-r--r--tests/test_serializer.py133
-rw-r--r--tests/test_signer.py99
-rw-r--r--tests/test_timed.py85
-rw-r--r--tests/test_url_safe.py24
-rw-r--r--tox.ini4
17 files changed, 533 insertions, 372 deletions
diff --git a/.gitignore b/.gitignore
index 7d04173..c6c3b67 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,4 @@ docs/_build/
.coverage
.coverage.*
htmlcov/
+.hypothesis/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e29622..7be015d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/asottile/reorder_python_imports
- rev: v1.2.0
+ rev: v1.3.1
hooks:
- id: reorder-python-imports
args: ["--application-directories", "src"]
@@ -9,7 +9,10 @@ repos:
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v1.4.0-1
+ rev: v2.0.0
hooks:
+ - id: check-byte-order-marker
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
- id: flake8
- additional_dependencies: [flake8-bugbear] \ No newline at end of file
+ additional_dependencies: [flake8-bugbear]
diff --git a/.travis.yml b/.travis.yml
index 92acace..e87eaa0 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -25,7 +25,9 @@ script:
- tox
cache:
- - pip
+ directories:
+ - $HOME/.cache/pip
+ - $HOME/.cache/pre-commit
branches:
only:
diff --git a/LICENSE.rst b/LICENSE.rst
index fbfceb2..e506dca 100644
--- a/LICENSE.rst
+++ b/LICENSE.rst
@@ -44,4 +44,4 @@ The initial implementation of It's Dangerous was inspired by Django's
signing module.
Copyright © Django Software Foundation and individual contributors.
-All rights reserved. \ No newline at end of file
+All rights reserved.
diff --git a/setup.cfg b/setup.cfg
index bfbcafd..6614658 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -33,4 +33,4 @@ ignore = E203, E501, W503
# up to 88 allowed by bugbear B950
max-line-length = 80
# init is used to export public API, ignore import warnings
-exclude = src/itsdangerous/__init__.py \ No newline at end of file
+exclude = src/itsdangerous/__init__.py
diff --git a/src/itsdangerous/_compat.py b/src/itsdangerous/_compat.py
index de70f9c..2291bce 100644
--- a/src/itsdangerous/_compat.py
+++ b/src/itsdangerous/_compat.py
@@ -16,7 +16,7 @@ else:
number_types = (numbers.Real, decimal.Decimal)
-def constant_time_compare(val1, val2):
+def _constant_time_compare(val1, val2):
"""Return ``True`` if the two strings are equal, ``False``
otherwise.
@@ -43,4 +43,4 @@ def constant_time_compare(val1, val2):
# Starting with 2.7/3.3 the standard library has a c-implementation for
# constant time string compares.
-constant_time_compare = getattr(hmac, "compare_digest", constant_time_compare)
+constant_time_compare = getattr(hmac, "compare_digest", _constant_time_compare)
diff --git a/src/itsdangerous/jws.py b/src/itsdangerous/jws.py
index b1f31de..92e9ec8 100644
--- a/src/itsdangerous/jws.py
+++ b/src/itsdangerous/jws.py
@@ -190,13 +190,13 @@ class TimedJSONWebSignatureSerializer(JSONWebSignatureSerializer):
if "exp" not in header:
raise BadSignature("Missing expiry date", payload=payload)
+ int_date_error = BadHeader("Expiry date is not an IntDate", payload=payload)
try:
header["exp"] = int(header["exp"])
except ValueError:
- raise BadHeader("Expiry date is not valid timestamp", payload=payload)
-
- if not (isinstance(header["exp"], number_types) and header["exp"] > 0):
- raise BadSignature("expiry date is not an IntDate", payload=payload)
+ raise int_date_error
+ if header["exp"] < 0:
+ raise int_date_error
if header["exp"] < self.now():
raise SignatureExpired(
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/test_compat.py b/tests/test_compat.py
new file mode 100644
index 0000000..2043fad
--- /dev/null
+++ b/tests/test_compat.py
@@ -0,0 +1,11 @@
+import pytest
+
+from itsdangerous._compat import _constant_time_compare
+
+
+@pytest.mark.parametrize(
+ ("a", "b", "expect"),
+ ((b"a", b"a", True), (b"a", b"b", False), (b"a", b"aa", False)),
+)
+def test_python_constant_time_compare(a, b, expect):
+ assert _constant_time_compare(a, b) == expect
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
new file mode 100644
index 0000000..d60ec17
--- /dev/null
+++ b/tests/test_encoding.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+from itsdangerous.encoding import base64_decode
+from itsdangerous.encoding import base64_encode
+from itsdangerous.encoding import bytes_to_int
+from itsdangerous.encoding import int_to_bytes
+from itsdangerous.encoding import want_bytes
+from itsdangerous.exc import BadData
+
+
+@pytest.mark.parametrize("value", (u"mañana", b"tomorrow"))
+def test_want_bytes(value):
+ out = want_bytes(value)
+ assert isinstance(out, bytes)
+
+
+@pytest.mark.parametrize("value", (u"無限", b"infinite"))
+def test_base64(value):
+ enc = base64_encode(value)
+ assert isinstance(enc, bytes)
+ dec = base64_decode(enc)
+ assert dec == want_bytes(value)
+
+
+def test_base64_bad():
+ with pytest.raises(BadData):
+ base64_decode("12345")
+
+
+@pytest.mark.parametrize(
+ ("value", "expect"), ((0, b""), (192, b"\xc0"), (18446744073709551615, b"\xff" * 8))
+)
+def test_int_bytes(value, expect):
+ enc = int_to_bytes(value)
+ assert enc == expect
+ dec = bytes_to_int(enc)
+ assert dec == value
diff --git a/tests/test_itsdangerous.py b/tests/test_itsdangerous.py
deleted file mode 100755
index 1a9f1ec..0000000
--- a/tests/test_itsdangerous.py
+++ /dev/null
@@ -1,359 +0,0 @@
-#!/usr/bin/env python
-import hashlib
-import pickle
-import time
-import unittest
-from datetime import datetime
-
-import pytest
-
-import itsdangerous
-from itsdangerous._compat import PY2
-from itsdangerous._compat import text_type
-from itsdangerous.encoding import want_bytes
-
-# Helper function for some unsafe string manipulation on encoded
-# data. This is required for Python 3 but would break on Python 2
-if PY2:
-
- def _coerce_string(reference_string, value):
- return value
-
-
-else:
-
- def _coerce_string(reference_string, value):
- assert isinstance(value, text_type), "rhs needs to be a string"
- if type(reference_string) != type(value):
- value = value.encode("utf-8")
- return value
-
-
-class UtilityTestCase(unittest.TestCase):
- def test_want_bytes(self):
- self.assertEqual(want_bytes(b"foobar"), b"foobar")
- self.assertEqual(want_bytes(u"foobar"), b"foobar")
-
-
-class SignerTestCase(unittest.TestCase):
- signer_class = itsdangerous.Signer
-
- def make_signer(self, *args, **kwargs):
- return self.signer_class(*args, **kwargs)
-
- def test_sign(self):
- s = self.make_signer("secret-key")
- assert isinstance(s.sign("my string"), bytes)
-
- def test_sign_invalid_separator(self):
- with pytest.raises(ValueError) as excinfo:
- self.make_signer("secret-key", sep="-")
- assert "separator cannot be used" in str(excinfo.value)
-
-
-class SerializerTestCase(unittest.TestCase):
- serializer_class = itsdangerous.Serializer
-
- def make_serializer(self, *args, **kwargs):
- return self.serializer_class(*args, **kwargs)
-
- def test_dumps_loads(self):
- objects = (
- ["a", "list"],
- "a string",
- u"a unicode string \u2019",
- {"a": "dictionary"},
- 42,
- 42.5,
- )
- s = self.make_serializer("Test")
- for o in objects:
- value = s.dumps(o)
- self.assertNotEqual(o, value)
- self.assertEqual(o, s.loads(value))
-
- def test_decode_detects_tampering(self):
- s = self.make_serializer("Test")
-
- transforms = (
- lambda s: s.upper(),
- lambda s: s + _coerce_string(s, "a"),
- lambda s: _coerce_string(s, "a") + s[1:],
- lambda s: s.replace(_coerce_string(s, "."), _coerce_string(s, "")),
- )
- value = {"foo": "bar", "baz": 1}
- encoded = s.dumps(value)
- self.assertEqual(value, s.loads(encoded))
- for transform in transforms:
- self.assertRaises(itsdangerous.BadSignature, s.loads, transform(encoded))
-
- def test_accepts_unicode(self):
- objects = (
- ["a", "list"],
- "a string",
- u"a unicode string \u2019",
- {"a": "dictionary"},
- 42,
- 42.5,
- )
- s = self.make_serializer("Test")
- for o in objects:
- value = s.dumps(o)
- self.assertNotEqual(o, value)
- self.assertEqual(o, s.loads(value))
-
- def test_exception_attributes(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
-
- try:
- s.loads(ts + _coerce_string(ts, "x"))
- except itsdangerous.BadSignature as e:
- self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 1)[0])
- self.assertEqual(s.load_payload(e.payload), value)
- else:
- self.fail("Did not get bad signature")
-
- def test_unsafe_load(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- self.assertEqual(s.loads_unsafe(ts), (True, u"hello"))
- self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello"))
-
- def test_load_unsafe_with_unicode_strings(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- self.assertEqual(s.loads_unsafe(ts), (True, u"hello"))
- self.assertEqual(s.loads_unsafe(ts, salt="modified"), (False, u"hello"))
-
- try:
- s.loads(ts, salt="modified")
- except itsdangerous.BadSignature as e:
- self.assertEqual(s.load_payload(e.payload), u"hello")
-
- def test_signer_kwargs(self):
- secret_key = "predictable-key"
- value = "hello"
- s = self.make_serializer(
- secret_key,
- signer_kwargs=dict(digest_method=hashlib.md5, key_derivation="hmac"),
- )
- ts = s.dumps(value)
- self.assertEqual(s.loads(ts), u"hello")
-
- def test_serializer_kwargs(self):
- s = self.make_serializer(
- "predictable-key", serializer_kwargs={"sort_keys": True}
- )
-
- # pickle tests pop serializer kwargs, so skip this test for those
- if not s.serializer_kwargs:
- return
-
- ts1 = s.dumps({"c": 3, "a": 1, "b": 2})
- ts2 = s.dumps(dict(a=1, b=2, c=3))
-
- self.assertEqual(ts1, ts2)
-
-
-class TimedSerializerTestCase(SerializerTestCase):
- serializer_class = itsdangerous.TimedSerializer
-
- def setUp(self):
- self._time = time.time
- time.time = lambda: 0
-
- def tearDown(self):
- time.time = self._time
-
- def test_decode_with_timeout(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- self.assertNotEqual(ts, itsdangerous.Serializer(secret_key).dumps(value))
-
- self.assertEqual(s.loads(ts), value)
- time.time = lambda: 10
- self.assertEqual(s.loads(ts, max_age=11), value)
- self.assertEqual(s.loads(ts, max_age=10), value)
- self.assertRaises(itsdangerous.SignatureExpired, s.loads, ts, max_age=9)
-
- def test_decode_return_timestamp(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- loaded, timestamp = s.loads(ts, return_timestamp=True)
- self.assertEqual(loaded, value)
- self.assertEqual(timestamp, datetime.utcfromtimestamp(time.time()))
-
- def test_exception_attributes(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- try:
- s.loads(ts, max_age=-1)
- except itsdangerous.SignatureExpired as e:
- self.assertEqual(e.date_signed, datetime.utcfromtimestamp(time.time()))
- self.assertEqual(want_bytes(e.payload), want_bytes(ts).rsplit(b".", 2)[0])
- self.assertEqual(s.load_payload(e.payload), value)
- else:
- self.fail("Did not get expiration")
-
-
-class JSONWebSignatureSerializerTestCase(SerializerTestCase):
- serializer_class = itsdangerous.JSONWebSignatureSerializer
-
- def test_decode_return_header(self):
- secret_key = "predictable-key"
- value = u"hello"
- header = {"typ": "dummy"}
-
- s = self.make_serializer(secret_key)
- full_header = header.copy()
- full_header["alg"] = s.algorithm_name
-
- ts = s.dumps(value, header_fields=header)
- loaded, loaded_header = s.loads(ts, return_header=True)
- self.assertEqual(loaded, value)
- self.assertEqual(loaded_header, full_header)
-
- def test_hmac_algorithms(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- algorithms = ("HS256", "HS384", "HS512")
- for algorithm in algorithms:
- s = self.make_serializer(secret_key, algorithm_name=algorithm)
- ts = s.dumps(value)
- self.assertEqual(s.loads(ts), value)
-
- def test_none_algorithm(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key)
- ts = s.dumps(value)
- self.assertEqual(s.loads(ts), value)
-
- def test_algorithm_mismatch(self):
- secret_key = "predictable-key"
- value = u"hello"
-
- s = self.make_serializer(secret_key, algorithm_name="HS256")
- ts = s.dumps(value)
-
- s = self.make_serializer(secret_key, algorithm_name="HS384")
- try:
- s.loads(ts)
- except itsdangerous.BadSignature as e:
- self.assertEqual(s.load_payload(e.payload), value)
- else:
- self.fail("Did not get algorithm mismatch")
-
-
-class TimedJSONWebSignatureSerializerTest(unittest.TestCase):
- serializer_class = itsdangerous.TimedJSONWebSignatureSerializer
-
- def test_token_contains_issue_date_and_expiry_time(self):
- s = self.serializer_class("secret")
- result = s.dumps({"es": "geht"})
- self.assertTrue("exp" in s.loads(result, return_header=True)[1])
- self.assertTrue("iat" in s.loads(result, return_header=True)[1])
-
- def test_token_expires_at_given_expiry_time(self):
- s = self.serializer_class("secret")
- an_hour_ago = int(time.time()) - 3601
- s.now = lambda: an_hour_ago
- result = s.dumps({"foo": "bar"})
- s = self.serializer_class("secret")
- self.assertRaises(itsdangerous.SignatureExpired, s.loads, result)
-
- def test_token_is_invalid_if_expiry_time_is_missing(self):
- bad_s = itsdangerous.JSONWebSignatureSerializer("secret")
- invalid_token_empty = bad_s.dumps({})
- s = self.serializer_class("secret")
- self.assertRaises(itsdangerous.BadSignature, s.loads, invalid_token_empty)
-
- def test_token_is_invalid_if_expiry_time_is_negative(self):
- s = self.serializer_class("secret", expires_in=-123)
- result = s.dumps({"foo": "bar"})
- self.assertRaises(itsdangerous.BadSignature, s.loads, result)
-
- def test_creating_a_token_adds_the_expiry_date(self):
- expires_in_two_hours = 7200
- s = self.serializer_class("secret", expires_in=expires_in_two_hours)
- result, header = s.loads(s.dumps({"foo": "bar"}), return_header=True)
- self.assertEqual(header["exp"] - header["iat"], expires_in_two_hours)
-
-
-class URLSafeSerializerMixin(object):
- def test_is_base62(self):
- allowed = frozenset(
- b"0123456789abcdefghijklmnopqrstuvwxyz" + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ_-."
- )
- objects = (
- ["a", "list"],
- "a string",
- u"a unicode string \u2019",
- {"a": "dictionary"},
- 42,
- 42.5,
- )
- s = self.make_serializer("Test")
- for o in objects:
- value = want_bytes(s.dumps(o))
- self.assertTrue(set(value).issubset(set(allowed)))
- self.assertNotEqual(o, value)
- self.assertEqual(o, s.loads(value))
-
- def test_invalid_base64_does_not_fail_load_payload(self):
- s = itsdangerous.URLSafeSerializer("aha!")
- self.assertRaises(itsdangerous.BadPayload, s.load_payload, b"kZ4m3du844lIN")
-
-
-class PickleSerializerMixin(object):
- def make_serializer(self, *args, **kwargs):
- kwargs.pop("serializer_kwargs", "")
- kwargs.setdefault("serializer", pickle)
- return super(PickleSerializerMixin, self).make_serializer(*args, **kwargs)
-
-
-class URLSafeSerializerTestCase(URLSafeSerializerMixin, SerializerTestCase):
- serializer_class = itsdangerous.URLSafeSerializer
-
-
-class URLSafeTimedSerializerTestCase(URLSafeSerializerMixin, TimedSerializerTestCase):
- serializer_class = itsdangerous.URLSafeTimedSerializer
-
-
-class PickleSerializerTestCase(PickleSerializerMixin, SerializerTestCase):
- pass
-
-
-class PickleTimedSerializerTestCase(PickleSerializerMixin, TimedSerializerTestCase):
- pass
-
-
-class PickleURLSafeSerializerTestCase(PickleSerializerMixin, URLSafeSerializerTestCase):
- pass
-
-
-class PickleURLSafeTimedSerializerTestCase(
- PickleSerializerMixin, URLSafeTimedSerializerTestCase
-):
- pass
diff --git a/tests/test_jws.py b/tests/test_jws.py
new file mode 100644
index 0000000..9938311
--- /dev/null
+++ b/tests/test_jws.py
@@ -0,0 +1,122 @@
+from functools import partial
+
+import pytest
+from tests.test_serializer import TestSerializer
+from tests.test_timed import TestTimedSerializer
+
+from itsdangerous.exc import BadData
+from itsdangerous.exc import BadHeader
+from itsdangerous.exc import BadPayload
+from itsdangerous.exc import BadSignature
+from itsdangerous.exc import SignatureExpired
+from itsdangerous.jws import JSONWebSignatureSerializer
+from itsdangerous.jws import TimedJSONWebSignatureSerializer
+
+
+class TestJWSSerializer(TestSerializer):
+ @pytest.fixture()
+ def serializer_factory(self):
+ return partial(JSONWebSignatureSerializer, secret_key="secret-key")
+
+ test_signer_cls = None
+ test_signer_kwargs = None
+
+ @pytest.mark.parametrize("algorithm_name", ("HS256", "HS384", "HS512", "none"))
+ def test_algorithm(self, serializer_factory, algorithm_name):
+ serializer = serializer_factory(algorithm_name=algorithm_name)
+ assert serializer.loads(serializer.dumps("value")) == "value"
+
+ def test_invalid_algorithm(self, serializer_factory):
+ with pytest.raises(NotImplementedError) as exc_info:
+ serializer_factory(algorithm_name="invalid")
+
+ assert "not supported" in str(exc_info.value)
+
+ def test_algorithm_mismatch(self, serializer_factory, serializer):
+ other = serializer_factory(algorithm_name="HS256")
+ other.algorithm = serializer.algorithm
+ signed = other.dumps("value")
+
+ with pytest.raises(BadHeader) as exc_info:
+ serializer.loads(signed)
+
+ assert "mismatch" in str(exc_info.value)
+
+ @pytest.mark.parametrize(
+ ("value", "exc_cls", "match"),
+ (
+ ("ab", BadPayload, '"."'),
+ ("a.b", BadHeader, "base64 decode"),
+ ("ew.b", BadPayload, "base64 decode"),
+ ("ew.ab", BadData, "malformed"),
+ ("W10.ab", BadHeader, "JSON object"),
+ ),
+ )
+ def test_load_payload_exceptions(self, serializer, value, exc_cls, match):
+ signer = serializer.make_signer()
+ signed = signer.sign(value)
+
+ with pytest.raises(exc_cls) as exc_info:
+ serializer.loads(signed)
+
+ assert match in str(exc_info.value)
+
+
+class TestTimedJWSSerializer(TestJWSSerializer, TestTimedSerializer):
+ @pytest.fixture()
+ def serializer_factory(self):
+ return partial(
+ TimedJSONWebSignatureSerializer, secret_key="secret-key", expires_in=10
+ )
+
+ def test_default_expires_in(self, serializer_factory):
+ serializer = serializer_factory(expires_in=None)
+ assert serializer.expires_in == serializer.DEFAULT_EXPIRES_IN
+
+ test_max_age = None
+
+ def test_exp(self, serializer, value, ts, freeze):
+ signed = serializer.dumps(value)
+ freeze.tick()
+ assert serializer.loads(signed) == value
+ freeze.tick(10)
+
+ with pytest.raises(SignatureExpired) as exc_info:
+ serializer.loads(signed)
+
+ assert exc_info.value.date_signed == ts
+ assert exc_info.value.payload == value
+
+ test_return_payload = None
+
+ def test_return_header(self, serializer, value, ts):
+ signed = serializer.dumps(value)
+ payload, header = serializer.loads(signed, return_header=True)
+ date_signed = serializer.get_issue_date(header)
+ assert (payload, date_signed) == (value, ts)
+
+ def test_missing_exp(self, serializer):
+ header = serializer.make_header(None)
+ del header["exp"]
+ signer = serializer.make_signer()
+ signed = signer.sign(serializer.dump_payload(header, "value"))
+
+ with pytest.raises(BadSignature):
+ serializer.loads(signed)
+
+ @pytest.mark.parametrize("exp", ("invalid", -1))
+ def test_invalid_exp(self, serializer, exp):
+ header = serializer.make_header(None)
+ header["exp"] = exp
+ signer = serializer.make_signer()
+ signed = signer.sign(serializer.dump_payload(header, "value"))
+
+ with pytest.raises(BadHeader) as exc_info:
+ serializer.loads(signed)
+
+ assert "IntDate" in str(exc_info.value)
+
+ def test_invalid_iat(self, serializer):
+ header = serializer.make_header(None)
+ header["iat"] = "invalid"
+ assert serializer.get_issue_date(header) is None
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
new file mode 100644
index 0000000..465d507
--- /dev/null
+++ b/tests/test_serializer.py
@@ -0,0 +1,133 @@
+import pickle
+from functools import partial
+from io import BytesIO
+from io import StringIO
+
+import pytest
+
+from itsdangerous.exc import BadPayload
+from itsdangerous.exc import BadSignature
+from itsdangerous.serializer import Serializer
+
+
+def coerce_str(ref, s):
+ if not isinstance(s, type(ref)):
+ return s.encode("utf8")
+
+ return s
+
+
+class TestSerializer(object):
+ @pytest.fixture(params=(Serializer, partial(Serializer, serializer=pickle)))
+ def serializer_factory(self, request):
+ return partial(request.param, secret_key="secret_key")
+
+ @pytest.fixture()
+ def serializer(self, serializer_factory):
+ return serializer_factory()
+
+ @pytest.fixture()
+ def value(self):
+ return {"id": 42}
+
+ @pytest.mark.parametrize(
+ "value", (None, True, "str", u"text", [1, 2, 3], {"id": 42})
+ )
+ def test_serializer(self, serializer, value):
+ assert serializer.loads(serializer.dumps(value)) == value
+
+ @pytest.mark.parametrize(
+ "transform",
+ (
+ lambda s: s.upper(),
+ lambda s: s + coerce_str(s, "a"),
+ lambda s: coerce_str(s, "a") + s[1:],
+ lambda s: s.replace(coerce_str(s, "."), coerce_str(s, "")),
+ ),
+ )
+ def test_changed_value(self, serializer, value, transform):
+ signed = serializer.dumps(value)
+ assert serializer.loads(signed) == value
+ changed = transform(signed)
+
+ with pytest.raises(BadSignature):
+ serializer.loads(changed)
+
+ def test_bad_signature_exception(self, serializer, value):
+ bad_signed = serializer.dumps(value)[:-1]
+
+ with pytest.raises(BadSignature) as exc_info:
+ serializer.loads(bad_signed)
+
+ assert serializer.load_payload(exc_info.value.payload) == value
+
+ def test_bad_payload_exception(self, serializer, value):
+ original = serializer.dumps(value)
+ payload = original.rsplit(coerce_str(original, "."), 1)[0]
+ bad = serializer.make_signer().sign(payload[:-1])
+
+ with pytest.raises(BadPayload) as exc_info:
+ serializer.loads(bad)
+
+ assert exc_info.value.original_error is not None
+
+ def test_loads_unsafe(self, serializer, value):
+ signed = serializer.dumps(value)
+ assert serializer.loads_unsafe(signed) == (True, value)
+
+ bad_signed = signed[:-1]
+ assert serializer.loads_unsafe(bad_signed) == (False, value)
+
+ payload = signed.rsplit(coerce_str(signed, "."), 1)[0]
+ bad_payload = serializer.make_signer().sign(payload[:-1])[:-1]
+ assert serializer.loads_unsafe(bad_payload) == (False, None)
+
+ class BadUnsign(serializer.signer):
+ def unsign(self, signed_value, *args, **kwargs):
+ try:
+ return super(BadUnsign, self).unsign(signed_value, *args, **kwargs)
+ except BadSignature as e:
+ e.payload = None
+ raise
+
+ serializer.signer = BadUnsign
+ assert serializer.loads_unsafe(bad_signed) == (False, None)
+
+ def test_file(self, serializer, value):
+ f = BytesIO() if isinstance(serializer.dumps(value), bytes) else StringIO()
+ serializer.dump(value, f)
+ f.seek(0)
+ assert serializer.load(f) == value
+ f.seek(0)
+ assert serializer.load_unsafe(f) == (True, value)
+
+ def test_alt_salt(self, serializer, value):
+ signed = serializer.dumps(value, salt="other")
+
+ with pytest.raises(BadSignature):
+ serializer.loads(signed)
+
+ assert serializer.loads(signed, salt="other") == value
+
+ def test_signer_cls(self, serializer_factory, serializer, value):
+ class Other(serializer.signer):
+ default_key_derivation = "hmac"
+
+ other = serializer_factory(signer=Other)
+ assert other.loads(other.dumps(value)) == value
+ assert other.dumps(value) != serializer.dumps(value)
+
+ def test_signer_kwargs(self, serializer_factory, serializer, value):
+ other = serializer_factory(signer_kwargs={"key_derivation": "hmac"})
+ assert other.loads(other.dumps(value)) == value
+ assert other.dumps("value") != serializer.dumps("value")
+
+ def test_serializer_kwargs(self, serializer_factory):
+ serializer = serializer_factory(serializer_kwargs={"skipkeys": True})
+
+ try:
+ serializer.serializer.dumps(None, skipkeys=True)
+ except TypeError:
+ return
+
+ assert serializer.loads(serializer.dumps({(): 1})) == {}
diff --git a/tests/test_signer.py b/tests/test_signer.py
new file mode 100644
index 0000000..5f7fe8e
--- /dev/null
+++ b/tests/test_signer.py
@@ -0,0 +1,99 @@
+import hashlib
+from functools import partial
+
+import pytest
+
+from itsdangerous.exc import BadSignature
+from itsdangerous.signer import HMACAlgorithm
+from itsdangerous.signer import NoneAlgorithm
+from itsdangerous.signer import Signer
+from itsdangerous.signer import SigningAlgorithm
+
+
+class _ReverseAlgorithm(SigningAlgorithm):
+ def get_signature(self, key, value):
+ return (key + value)[::-1]
+
+
+class TestSigner(object):
+ @pytest.fixture()
+ def signer_factory(self):
+ return partial(Signer, secret_key="secret-key")
+
+ @pytest.fixture()
+ def signer(self, signer_factory):
+ return signer_factory()
+
+ def test_signer(self, signer):
+ signed = signer.sign("my string")
+ assert isinstance(signed, bytes)
+ assert signer.validate(signed)
+ out = signer.unsign(signed)
+ assert out == b"my string"
+
+ def test_no_separator(self, signer):
+ signed = signer.sign("my string")
+ signed = signed.replace(signer.sep, b"*", 1)
+ assert not signer.validate(signed)
+
+ with pytest.raises(BadSignature):
+ signer.unsign(signed)
+
+ def test_broken_signature(self, signer):
+ signed = signer.sign("b")
+ bad_signed = signed[:-1]
+ bad_sig = bad_signed.rsplit(b".", 1)[1]
+ assert not signer.verify_signature(b"b", bad_sig)
+
+ with pytest.raises(BadSignature) as exc_info:
+ signer.unsign(bad_signed)
+
+ assert exc_info.value.payload == b"b"
+
+ def test_changed_value(self, signer):
+ signed = signer.sign("my string")
+ signed = signed.replace(b"my", b"other", 1)
+ assert not signer.validate(signed)
+
+ with pytest.raises(BadSignature):
+ signer.unsign(signed)
+
+ def test_invalid_separator(self, signer_factory):
+ with pytest.raises(ValueError) as exc_info:
+ signer_factory(sep="-")
+
+ assert "separator cannot be used" in str(exc_info.value)
+
+ @pytest.mark.parametrize(
+ "key_derivation", ("concat", "django-concat", "hmac", "none")
+ )
+ def test_key_derivation(self, signer_factory, key_derivation):
+ signer = signer_factory(key_derivation=key_derivation)
+ assert signer.unsign(signer.sign("value")) == b"value"
+
+ def test_invalid_key_derivation(self, signer_factory):
+ signer = signer_factory(key_derivation="invalid")
+
+ with pytest.raises(TypeError):
+ signer.derive_key()
+
+ def test_digest_method(self, signer_factory):
+ signer = signer_factory(digest_method=hashlib.md5)
+ assert signer.unsign(signer.sign("value")) == b"value"
+
+ @pytest.mark.parametrize(
+ "algorithm", (None, NoneAlgorithm(), HMACAlgorithm(), _ReverseAlgorithm())
+ )
+ def test_algorithm(self, signer_factory, algorithm):
+ signer = signer_factory(algorithm=algorithm)
+ assert signer.unsign(signer.sign("value")) == b"value"
+
+ if algorithm is None:
+ assert signer.algorithm.digest_method == signer.digest_method
+
+
+def test_abstract_algorithm():
+ alg = SigningAlgorithm()
+
+ with pytest.raises(NotImplementedError):
+ alg.get_signature("a", "b")
diff --git a/tests/test_timed.py b/tests/test_timed.py
new file mode 100644
index 0000000..36c1a86
--- /dev/null
+++ b/tests/test_timed.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from functools import partial
+
+import pytest
+from freezegun import freeze_time
+from tests.test_serializer import TestSerializer
+from tests.test_signer import TestSigner
+
+from itsdangerous import Signer
+from itsdangerous.exc import BadTimeSignature
+from itsdangerous.exc import SignatureExpired
+from itsdangerous.timed import TimedSerializer
+from itsdangerous.timed import TimestampSigner
+
+
+class FreezeMixin(object):
+ @pytest.fixture()
+ def ts(self):
+ return datetime(2011, 6, 24, 0, 9, 5)
+
+ @pytest.fixture(autouse=True)
+ def freeze(self, ts):
+ with freeze_time(ts) as ft:
+ yield ft
+
+
+class TestTimestampSigner(FreezeMixin, TestSigner):
+ @pytest.fixture()
+ def signer_factory(self):
+ return partial(TimestampSigner, secret_key="secret-key")
+
+ def test_max_age(self, signer, ts, freeze):
+ signed = signer.sign("value")
+ freeze.tick()
+ assert signer.unsign(signed, max_age=10) == b"value"
+ freeze.tick(10)
+
+ with pytest.raises(SignatureExpired) as exc_info:
+ signer.unsign(signed, max_age=10)
+
+ assert exc_info.value.date_signed == ts
+
+ def test_return_timestamp(self, signer, ts):
+ signed = signer.sign("value")
+ assert signer.unsign(signed, return_timestamp=True) == (b"value", ts)
+
+ def test_timestamp_missing(self, signer):
+ other = Signer("secret-key")
+ signed = other.sign("value")
+
+ with pytest.raises(BadTimeSignature) as exc_info:
+ signer.unsign(signed)
+
+ assert "missing" in str(exc_info.value)
+
+ def test_malformed_timestamp(self, signer):
+ other = Signer("secret-key")
+ signed = other.sign(b"value.____________")
+
+ with pytest.raises(BadTimeSignature) as exc_info:
+ signer.unsign(signed)
+
+ assert "Malformed" in str(exc_info.value)
+
+
+class TestTimedSerializer(FreezeMixin, TestSerializer):
+ @pytest.fixture()
+ def serializer_factory(self):
+ return partial(TimedSerializer, secret_key="secret_key")
+
+ def test_max_age(self, serializer, value, ts, freeze):
+ signed = serializer.dumps(value)
+ freeze.tick()
+ assert serializer.loads(signed, max_age=10) == value
+ freeze.tick(10)
+
+ with pytest.raises(SignatureExpired) as exc_info:
+ serializer.loads(signed, max_age=10)
+
+ assert exc_info.value.date_signed == ts
+ assert serializer.load_payload(exc_info.value.payload) == value
+
+ def test_return_payload(self, serializer, value, ts):
+ signed = serializer.dumps(value)
+ assert serializer.loads(signed, return_timestamp=True) == (value, ts)
diff --git a/tests/test_url_safe.py b/tests/test_url_safe.py
new file mode 100644
index 0000000..5cb7f2c
--- /dev/null
+++ b/tests/test_url_safe.py
@@ -0,0 +1,24 @@
+from functools import partial
+
+import pytest
+from tests.test_serializer import TestSerializer
+from tests.test_timed import TestTimedSerializer
+
+from itsdangerous import URLSafeSerializer
+from itsdangerous import URLSafeTimedSerializer
+
+
+class TestURLSafeSerializer(TestSerializer):
+ @pytest.fixture()
+ def serializer_factory(self):
+ return partial(URLSafeSerializer, secret_key="secret-key")
+
+ @pytest.fixture(params=({"id": 42}, pytest.param("a" * 1000, id="zlib")))
+ def value(self, request):
+ return request.param
+
+
+class TestURLSafeTimedSerializer(TestURLSafeSerializer, TestTimedSerializer):
+ @pytest.fixture()
+ def serializer_factory(self):
+ return partial(URLSafeTimedSerializer, secret_key="secret-key")
diff --git a/tox.ini b/tox.ini
index a91897d..d64c2c9 100644
--- a/tox.ini
+++ b/tox.ini
@@ -9,7 +9,9 @@ skip_missing_interpreters = true
[testenv]
setenv =
COVERAGE_FILE = .coverage.{envname}
-deps = pytest-cov
+deps =
+ pytest-cov
+ freezegun
commands = pytest --cov --cov-report= {posargs}
[testenv:stylecheck]