diff options
author | Jon Dufresne <jon.dufresne@gmail.com> | 2020-12-17 12:09:56 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-17 15:09:56 -0500 |
commit | 97cd4031549d363fa34737635eab33491fa718fe (patch) | |
tree | e11c7c55e578c3ce17b7851ba92cc0f91d947e12 | |
parent | a1907c037a44e3dda54ea60163a15e11cfd81774 (diff) | |
download | pyjwt-97cd4031549d363fa34737635eab33491fa718fe.tar.gz |
Tighten bytes/str boundaries and remove unnecessary coercing (#547)
Now that the project is Python 3 only, the boundaries between bytes and
Unicode strings is more explicit.
This allows removing several unnecessary force_bytes() and
force_unicode() calls that handled differences between Python 2 and
Python 3. All uses of force_unicode() have been removed.
For values that are known to be bytes, use `.decode()` instead. For
values are that known to be str, use `.encode()` instead. This strategy
makes the type explicit and reduces a function call.
Key handling continues to use force_bytes() to allow callers to pass
either bytes or str.
To help enforce bytes/str handling in the future, the `-b` option is
passed to Python when testing. This option will emit a warning if bytes
and str are improperly mixed together.
-rw-r--r-- | jwt/algorithms.py | 27 | ||||
-rw-r--r-- | jwt/api_jws.py | 8 | ||||
-rw-r--r-- | jwt/utils.py | 9 | ||||
-rw-r--r-- | tests/test_api_jws.py | 8 | ||||
-rw-r--r-- | tests/test_utils.py | 12 | ||||
-rw-r--r-- | tox.ini | 2 |
6 files changed, 21 insertions, 45 deletions
diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 2eb090c..c94f631 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -8,7 +8,6 @@ from .utils import ( base64url_encode, der_to_raw_signature, force_bytes, - force_unicode, from_base64url_uint, raw_to_der_signature, to_base64url_uint, @@ -194,7 +193,7 @@ class HMACAlgorithm(Algorithm): def to_jwk(key_obj): return json.dumps( { - "k": force_unicode(base64url_encode(force_bytes(key_obj))), + "k": base64url_encode(force_bytes(key_obj)).decode(), "kty": "oct", } ) @@ -260,18 +259,14 @@ if has_crypto: # noqa: C901 obj = { "kty": "RSA", "key_ops": ["sign"], - "n": force_unicode( - to_base64url_uint(numbers.public_numbers.n) - ), - "e": force_unicode( - to_base64url_uint(numbers.public_numbers.e) - ), - "d": force_unicode(to_base64url_uint(numbers.d)), - "p": force_unicode(to_base64url_uint(numbers.p)), - "q": force_unicode(to_base64url_uint(numbers.q)), - "dp": force_unicode(to_base64url_uint(numbers.dmp1)), - "dq": force_unicode(to_base64url_uint(numbers.dmq1)), - "qi": force_unicode(to_base64url_uint(numbers.iqmp)), + "n": to_base64url_uint(numbers.public_numbers.n).decode(), + "e": to_base64url_uint(numbers.public_numbers.e).decode(), + "d": to_base64url_uint(numbers.d).decode(), + "p": to_base64url_uint(numbers.p).decode(), + "q": to_base64url_uint(numbers.q).decode(), + "dp": to_base64url_uint(numbers.dmp1).decode(), + "dq": to_base64url_uint(numbers.dmq1).decode(), + "qi": to_base64url_uint(numbers.iqmp).decode(), } elif getattr(key_obj, "verify", None): @@ -281,8 +276,8 @@ if has_crypto: # noqa: C901 obj = { "kty": "RSA", "key_ops": ["verify"], - "n": force_unicode(to_base64url_uint(numbers.n)), - "e": force_unicode(to_base64url_uint(numbers.e)), + "n": to_base64url_uint(numbers.n).decode(), + "e": to_base64url_uint(numbers.e).decode(), } else: raise InvalidKeyError("Not a public or private key") diff --git a/jwt/api_jws.py b/jwt/api_jws.py index bce7a1a..cf0d8fb 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -11,7 +11,7 @@ from .exceptions import ( InvalidSignatureError, InvalidTokenError, ) -from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict +from .utils import base64url_decode, base64url_encode, merge_dict class PyJWS: @@ -95,9 +95,9 @@ class PyJWS: self._validate_headers(headers) header.update(headers) - json_header = force_bytes( - json.dumps(header, separators=(",", ":"), cls=json_encoder) - ) + json_header = json.dumps( + header, separators=(",", ":"), cls=json_encoder + ).encode() segments.append(base64url_encode(json_header)) segments.append(base64url_encode(payload)) diff --git a/jwt/utils.py b/jwt/utils.py index 0e8f210..a617342 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -11,15 +11,6 @@ except ImportError: pass -def force_unicode(value): - if isinstance(value, bytes): - return value.decode("utf-8") - elif isinstance(value, str): - return value - else: - raise TypeError("Expected a string value") - - def force_bytes(value): if isinstance(value, str): return value.encode("utf-8") diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index c06f7c5..a0456dc 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -11,7 +11,7 @@ from jwt.exceptions import ( InvalidSignatureError, InvalidTokenError, ) -from jwt.utils import base64url_decode, force_bytes +from jwt.utils import base64url_decode from .utils import key_path @@ -613,7 +613,7 @@ class TestJWS: # PEM-formatted EC key with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: priv_eckey = load_pem_private_key( - force_bytes(ec_priv_file.read()), password=None + ec_priv_file.read(), password=None ) jws_message = jws.encode(payload, priv_eckey, algorithm="ES384") @@ -637,7 +637,7 @@ class TestJWS: # PEM-formatted EC key with open(key_path("testkey_ec.priv"), "rb") as ec_priv_file: priv_eckey = load_pem_private_key( - force_bytes(ec_priv_file.read()), password=None + ec_priv_file.read(), password=None ) jws_message = jws.encode(payload, priv_eckey, algorithm="ES512") @@ -700,7 +700,7 @@ class TestJWS: payload, "secret", headers=data, json_encoder=CustomJSONEncoder ) - header = force_bytes(token.split(".")[0]) + header, *_ = token.split(".") header = json.loads(base64url_decode(header)) assert "some_decimal" in header diff --git a/tests/test_utils.py b/tests/test_utils.py index cb73c52..a089f86 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,6 @@ import pytest -from jwt.utils import ( - force_bytes, - force_unicode, - from_base64url_uint, - to_base64url_uint, -) +from jwt.utils import force_bytes, from_base64url_uint, to_base64url_uint @pytest.mark.parametrize( @@ -39,11 +34,6 @@ def test_from_base64url_uint(inputval, expected): assert actual == expected -def test_force_unicode_raises_error_on_invalid_object(): - with pytest.raises(TypeError): - force_unicode({}) - - def test_force_bytes_raises_error_on_invalid_object(): with pytest.raises(TypeError): force_bytes({}) @@ -41,7 +41,7 @@ setenv = extras = tests py{36,37,38,39}-crypto-{linux,windows}: crypto -commands = coverage run -m pytest {posargs} +commands = {envpython} -b -m coverage run -m pytest {posargs} [testenv:docs] |