diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/__init__.py | 3 | ||||
-rw-r--r-- | tests/oauth1/rfc5849/endpoints/test_base.py | 13 | ||||
-rw-r--r-- | tests/oauth1/rfc5849/test_signatures.py | 39 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/endpoints/test_base_endpoint.py | 2 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/endpoints/test_error_responses.py | 55 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py | 30 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py | 29 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/grant_types/test_authorization_code.py | 6 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/test_parameters.py | 27 | ||||
-rw-r--r-- | tests/oauth2/rfc6749/test_tokens.py | 79 | ||||
-rw-r--r-- | tests/openid/connect/core/endpoints/test_claims_handling.py | 2 | ||||
-rw-r--r-- | tests/openid/connect/core/grant_types/test_authorization_code.py | 1 | ||||
-rw-r--r-- | tests/openid/connect/core/grant_types/test_base.py | 104 | ||||
-rw-r--r-- | tests/openid/connect/core/test_request_validator.py | 6 | ||||
-rw-r--r-- | tests/openid/connect/core/test_tokens.py | 6 | ||||
-rw-r--r-- | tests/test_common.py | 16 |
16 files changed, 386 insertions, 32 deletions
diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..f33236b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +import oauthlib + +oauthlib.set_debug(True) diff --git a/tests/oauth1/rfc5849/endpoints/test_base.py b/tests/oauth1/rfc5849/endpoints/test_base.py index 60f7860..795ddee 100644 --- a/tests/oauth1/rfc5849/endpoints/test_base.py +++ b/tests/oauth1/rfc5849/endpoints/test_base.py @@ -4,7 +4,7 @@ from re import sub from mock import MagicMock -from oauthlib.common import safe_string_equals +from oauthlib.common import CaseInsensitiveDict, safe_string_equals from oauthlib.oauth1 import Client, RequestValidator from oauthlib.oauth1.rfc5849 import (SIGNATURE_HMAC, SIGNATURE_PLAINTEXT, SIGNATURE_RSA, errors) @@ -179,6 +179,17 @@ class BaseEndpointTest(TestCase): self.assertRaises(errors.InvalidRequestError, e._check_mandatory_parameters, r) + def test_case_insensitive_headers(self): + """Ensure headers are case-insensitive""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b', 'POST', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789a'), + URLENCODED) + self.assertIsInstance(r.headers, CaseInsensitiveDict) + def test_signature_method_validation(self): """Ensure valid signature method is used.""" diff --git a/tests/oauth1/rfc5849/test_signatures.py b/tests/oauth1/rfc5849/test_signatures.py index 48609e5..bb0dc78 100644 --- a/tests/oauth1/rfc5849/test_signatures.py +++ b/tests/oauth1/rfc5849/test_signatures.py @@ -3,8 +3,8 @@ from __future__ import absolute_import, unicode_literals from oauthlib.common import unicode_type from oauthlib.oauth1.rfc5849.signature import (collect_parameters, - construct_base_string, - normalize_base_string_uri, + signature_base_string, + base_string_uri, normalize_parameters, sign_hmac_sha1, sign_hmac_sha1_with_client, @@ -79,7 +79,7 @@ class SignatureTests(TestCase): resource_owner_secret = self.resource_owner_secret ) - def test_construct_base_string(self): + def test_signature_base_string(self): """ Example text to be turned into a base string:: @@ -104,20 +104,20 @@ class SignatureTests(TestCase): D%2522137131201%2522%252Coauth_nonce%253D%25227d8f3e4a%2522%252Coau th_signature%253D%2522bYT5CMsGcbgUdFHObYMEfcx6bsw%25253D%2522 """ - self.assertRaises(ValueError, construct_base_string, + self.assertRaises(ValueError, signature_base_string, self.http_method, self.base_string_url, self.normalized_encoded_request_parameters) - self.assertRaises(ValueError, construct_base_string, + self.assertRaises(ValueError, signature_base_string, self.http_method.decode('utf-8'), self.base_string_url, self.normalized_encoded_request_parameters) - self.assertRaises(ValueError, construct_base_string, + self.assertRaises(ValueError, signature_base_string, self.http_method.decode('utf-8'), self.base_string_url.decode('utf-8'), self.normalized_encoded_request_parameters) - base_string = construct_base_string( + base_string = signature_base_string( self.http_method.decode('utf-8'), self.base_string_url.decode('utf-8'), self.normalized_encoded_request_parameters.decode('utf-8') @@ -125,7 +125,7 @@ class SignatureTests(TestCase): self.assertEqual(self.control_base_string, base_string) - def test_normalize_base_string_uri(self): + def test_base_string_uri(self): """ Example text to be turned into a normalized base string uri:: @@ -137,33 +137,44 @@ class SignatureTests(TestCase): https://www.example.net:8080/ """ + # test first example from RFC 5849 section 3.4.1.2. + # Note: there is a space between "r" and "v" + uri = 'http://EXAMPLE.COM:80/r v/X?id=123' + self.assertEqual(base_string_uri(uri), + 'http://example.com/r%20v/X') + + # test second example from RFC 5849 section 3.4.1.2. + uri = 'https://www.example.net:8080/?q=1' + self.assertEqual(base_string_uri(uri), + 'https://www.example.net:8080/') + # test for unicode failure uri = b"www.example.com:8080" - self.assertRaises(ValueError, normalize_base_string_uri, uri) + self.assertRaises(ValueError, base_string_uri, uri) # test for missing scheme uri = "www.example.com:8080" - self.assertRaises(ValueError, normalize_base_string_uri, uri) + self.assertRaises(ValueError, base_string_uri, uri) # test a URI with the default port uri = "http://www.example.com:80/" - self.assertEqual(normalize_base_string_uri(uri), + self.assertEqual(base_string_uri(uri), "http://www.example.com/") # test a URI missing a path uri = "http://www.example.com" - self.assertEqual(normalize_base_string_uri(uri), + self.assertEqual(base_string_uri(uri), "http://www.example.com/") # test a relative URI uri = "/a-host-relative-uri" host = "www.example.com" - self.assertRaises(ValueError, normalize_base_string_uri, (uri, host)) + self.assertRaises(ValueError, base_string_uri, (uri, host)) # test overriding the URI's netloc with a host argument uri = "http://www.example.com/a-path" host = "alternatehost.example.com" - self.assertEqual(normalize_base_string_uri(uri, host), + self.assertEqual(base_string_uri(uri, host), "http://alternatehost.example.com/a-path") def test_collect_parameters(self): diff --git a/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py index 4f78d9b..bf04a42 100644 --- a/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py +++ b/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py @@ -25,7 +25,7 @@ class BaseEndpointTest(TestCase): server = Server(validator) server.catch_errors = True h, b, s = server.create_token_response( - 'https://example.com?grant_type=authorization_code&code=abc' + 'https://example.com', body='grant_type=authorization_code&code=abc' ) self.assertIn("server_error", b) self.assertEqual(s, 500) diff --git a/tests/oauth2/rfc6749/endpoints/test_error_responses.py b/tests/oauth2/rfc6749/endpoints/test_error_responses.py index a249cb1..2479836 100644 --- a/tests/oauth2/rfc6749/endpoints/test_error_responses.py +++ b/tests/oauth2/rfc6749/endpoints/test_error_responses.py @@ -6,11 +6,11 @@ import json import mock +from oauthlib.common import urlencode from oauthlib.oauth2 import (BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, RequestValidator, WebApplicationServer) from oauthlib.oauth2.rfc6749 import errors - from ....unittest import TestCase @@ -437,3 +437,56 @@ class ErrorResponseTest(TestCase): _, body, _ = self.backend.create_token_response('https://i.b/token', body='grant_type=bar') self.assertEqual('unsupported_grant_type', json.loads(body)['error']) + + def test_invalid_request_method(self): + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + self.validator.authenticate_client.side_effect = self.set_client + + uri = "http://i/b/token/" + try: + _, body, s = self.web.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + try: + _, body, s = self.legacy.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + try: + _, body, s = self.backend.create_token_response(uri, + body='grant_type=access_token&code=123', http_method=method) + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('Unsupported request method', ire.description) + + def test_invalid_post_request(self): + self.validator.authenticate_client.side_effect = self.set_client + for param in ['token', 'secret', 'code', 'foo']: + uri = 'https://i/b/token?' + urlencode([(param, 'secret')]) + try: + _, body, s = self.web.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) + + try: + _, body, s = self.legacy.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) + + try: + _, body, s = self.backend.create_token_response(uri, + body='grant_type=access_token&code=123') + self.fail('This should have failed with InvalidRequestError') + except errors.InvalidRequestError as ire: + self.assertIn('URL query parameters are not allowed', ire.description) diff --git a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py index b9bf76a..ae3deae 100644 --- a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py +++ b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py @@ -139,3 +139,33 @@ class IntrospectEndpointTest(TestCase): self.assertEqual(h, self.resp_h) self.assertEqual(loads(b)['error'], 'invalid_request') self.assertEqual(s, 400) + + def test_introspect_invalid_request_method(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_introspect_response(self.uri, + http_method = method, headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('Unsupported request method', loads(b)['error_description']) + self.assertEqual(s, 400) + + def test_introspect_bad_post_request(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + for param in ['token', 'secret', 'code', 'foo']: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_introspect_response( + uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('query parameters are not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) diff --git a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py index 2a24177..17be3a5 100644 --- a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py +++ b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py @@ -120,3 +120,32 @@ class RevocationEndpointTest(TestCase): self.assertEqual(h, self.resp_h) self.assertEqual(loads(b)['error'], 'invalid_request') self.assertEqual(s, 400) + + def test_revoke_invalid_request_method(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH'] + test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods] + for method in test_methods: + body = urlencode([('token', 'foo'), + ('token_type_hint', 'refresh_token')]) + h, b, s = endpoint.create_revocation_response(self.uri, + http_method = method, headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('Unsupported request method', loads(b)['error_description']) + self.assertEqual(s, 400) + + def test_revoke_bad_post_request(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + for param in ['token', 'secret', 'code', 'foo']: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_revocation_response(uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn('query parameters are not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py index 00e2b6d..2c9db3c 100644 --- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py +++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py @@ -215,8 +215,10 @@ class AuthorizationCodeGrantTest(TestCase): self.mock_validator.is_pkce_required.return_value = required self.request.code_challenge = "present" _, ri = self.auth.validate_authorization_request(self.request) - self.assertIsNotNone(ri["request"].code_challenge_method) - self.assertEqual(ri["request"].code_challenge_method, "plain") + self.assertIn("code_challenge", ri) + self.assertIn("code_challenge_method", ri) + self.assertEqual(ri["code_challenge"], "present") + self.assertEqual(ri["code_challenge_method"], "plain") def test_pkce_wrong_method(self): for required in [True, False]: diff --git a/tests/oauth2/rfc6749/test_parameters.py b/tests/oauth2/rfc6749/test_parameters.py index c42f516..48b7eac 100644 --- a/tests/oauth2/rfc6749/test_parameters.py +++ b/tests/oauth2/rfc6749/test_parameters.py @@ -73,7 +73,8 @@ class ParameterTests(TestCase): error_nocode = 'https://client.example.com/cb?state=xyz' error_nostate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA' error_wrongstate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=abc' - error_response = 'https://client.example.com/cb?error=access_denied&state=xyz' + error_denied = 'https://client.example.com/cb?error=access_denied&state=xyz' + error_invalid = 'https://client.example.com/cb?error=invalid_request&state=xyz' implicit_base = 'https://example.com/cb#access_token=2YotnFZFEjr1zCsicMWpAA&scope=abc&' implicit_response = implicit_base + 'state={0}&token_type=example&expires_in=3600'.format(state) @@ -102,6 +103,15 @@ class ParameterTests(TestCase): ' "expires_in": 3600,' ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' ' "example_parameter": "example_value" }') + json_response_noexpire = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value"}') + json_response_expirenull = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",' + ' "token_type": "example",' + ' "expires_in": null,' + ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",' + ' "example_parameter": "example_value"}') json_custom_error = '{ "error": "incorrect_client_credentials" }' json_error = '{ "error": "access_denied" }' @@ -135,6 +145,13 @@ class ParameterTests(TestCase): 'example_parameter': 'example_value' } + json_noexpire_dict = { + 'access_token': '2YotnFZFEjr1zCsicMWpAA', + 'token_type': 'example', + 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA', + 'example_parameter': 'example_value' + } + json_notype_dict = { 'access_token': '2YotnFZFEjr1zCsicMWpAA', 'expires_in': 3600, @@ -180,8 +197,10 @@ class ParameterTests(TestCase): self.assertRaises(MissingCodeError, parse_authorization_code_response, self.error_nocode) - self.assertRaises(MissingCodeError, parse_authorization_code_response, - self.error_response) + self.assertRaises(AccessDeniedError, parse_authorization_code_response, + self.error_denied) + self.assertRaises(InvalidRequestFatalError, parse_authorization_code_response, + self.error_invalid) self.assertRaises(MismatchingStateError, parse_authorization_code_response, self.error_nostate, state=self.state) self.assertRaises(MismatchingStateError, parse_authorization_code_response, @@ -209,6 +228,8 @@ class ParameterTests(TestCase): self.assertEqual(parse_token_response(self.json_response_noscope, scope=['all', 'the', 'scopes']), self.json_noscope_dict) + self.assertEqual(parse_token_response(self.json_response_noexpire), self.json_noexpire_dict) + self.assertEqual(parse_token_response(self.json_response_expirenull), self.json_noexpire_dict) scope_changes_recorded = [] def record_scope_change(sender, message, old, new): diff --git a/tests/oauth2/rfc6749/test_tokens.py b/tests/oauth2/rfc6749/test_tokens.py index 061754f..e6f49b1 100644 --- a/tests/oauth2/rfc6749/test_tokens.py +++ b/tests/oauth2/rfc6749/test_tokens.py @@ -1,10 +1,14 @@ from __future__ import absolute_import, unicode_literals +import mock + +from oauthlib.common import Request from oauthlib.oauth2.rfc6749.tokens import ( - prepare_mac_header, - prepare_bearer_headers, + BearerToken, prepare_bearer_body, + prepare_bearer_headers, prepare_bearer_uri, + prepare_mac_header, ) from ...unittest import TestCase @@ -64,6 +68,7 @@ class TokenTest(TestCase): bearer_headers = { 'Authorization': 'Bearer vF9dft4qmT' } + valid_bearer_header_lowercase = {"Authorization": "bearer vF9dft4qmT"} fake_bearer_headers = [ {'Authorization': 'Beaver vF9dft4qmT'}, {'Authorization': 'BeavervF9dft4qmT'}, @@ -98,3 +103,73 @@ class TokenTest(TestCase): self.assertEqual(prepare_bearer_headers(self.token), self.bearer_headers) self.assertEqual(prepare_bearer_body(self.token), self.bearer_body) self.assertEqual(prepare_bearer_uri(self.token, uri=self.uri), self.bearer_uri) + + def test_valid_bearer_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.bearer_headers) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + self.assertTrue(result) + + def test_lowercase_bearer_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.valid_bearer_header_lowercase) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + self.assertTrue(result) + + def test_fake_bearer_is_not_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + for fake_header in self.fake_bearer_headers: + request = Request("/", headers=fake_header) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + + self.assertFalse(result) + + def test_header_with_multispaces_is_validated(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + request = Request("/", headers=self.valid_header_with_multiple_spaces) + result = BearerToken(request_validator=request_validator).validate_request( + request + ) + + self.assertTrue(result) + + def test_estimate_type(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + request = Request("/", headers=self.bearer_headers) + result = BearerToken(request_validator=request_validator).estimate_type(request) + self.assertEqual(result, 9) + + def test_estimate_type_with_fake_header_returns_type_0(self): + request_validator = mock.MagicMock() + request_validator.validate_bearer_token = self._mocked_validate_bearer_token + + for fake_header in self.fake_bearer_headers: + request = Request("/", headers=fake_header) + result = BearerToken(request_validator=request_validator).estimate_type( + request + ) + + if ( + fake_header["Authorization"].count(" ") == 2 + and fake_header["Authorization"].split()[0] == "Bearer" + ): + # If we're dealing with the header containing 2 spaces, it will be recognized + # as a Bearer valid header, the token itself will be invalid by the way. + self.assertEqual(result, 9) + else: + self.assertEqual(result, 0) diff --git a/tests/openid/connect/core/endpoints/test_claims_handling.py b/tests/openid/connect/core/endpoints/test_claims_handling.py index 270ef69..5f39d96 100644 --- a/tests/openid/connect/core/endpoints/test_claims_handling.py +++ b/tests/openid/connect/core/endpoints/test_claims_handling.py @@ -10,7 +10,7 @@ from __future__ import absolute_import, unicode_literals import mock -from oauthlib.oauth2 import RequestValidator +from oauthlib.openid import RequestValidator from oauthlib.openid.connect.core.endpoints.pre_configured import Server from tests.unittest import TestCase diff --git a/tests/openid/connect/core/grant_types/test_authorization_code.py b/tests/openid/connect/core/grant_types/test_authorization_code.py index 89401ab..b721a19 100644 --- a/tests/openid/connect/core/grant_types/test_authorization_code.py +++ b/tests/openid/connect/core/grant_types/test_authorization_code.py @@ -117,7 +117,6 @@ class OpenIDAuthCodeTest(TestCase): def set_scopes(self, client_id, code, client, request): request.scopes = self.request.scopes - request.state = self.request.state request.user = 'bob' return True diff --git a/tests/openid/connect/core/grant_types/test_base.py b/tests/openid/connect/core/grant_types/test_base.py new file mode 100644 index 0000000..d506b7e --- /dev/null +++ b/tests/openid/connect/core/grant_types/test_base.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +import mock +import time + +from oauthlib.common import Request +from oauthlib.openid.connect.core.grant_types.base import GrantTypeBase + +from tests.unittest import TestCase + + +class GrantBase(GrantTypeBase): + """Class to test GrantTypeBase""" + def __init__(self, request_validator=None, **kwargs): + self.request_validator = request_validator + + +class IDTokenTest(TestCase): + + def setUp(self): + self.request = Request('http://a.b/path') + self.request.scopes = ('hello', 'openid') + self.request.expires_in = 1800 + self.request.client_id = 'abcdef' + self.request.code = '1234' + self.request.response_type = 'id_token' + self.request.grant_type = 'authorization_code' + self.request.redirect_uri = 'https://a.b/cb' + self.request.state = 'abc' + self.request.nonce = None + + self.mock_validator = mock.MagicMock() + self.mock_validator.get_id_token.return_value = None + self.mock_validator.finalize_id_token.return_value = "eyJ.body.signature" + self.token = {} + + self.grant = GrantBase(request_validator=self.mock_validator) + + self.url_query = 'https://a.b/cb?code=abc&state=abc' + self.url_fragment = 'https://a.b/cb#code=abc&state=abc' + + def test_id_token_hash(self): + self.assertEqual(self.grant.id_token_hash( + "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk", + ), "LDktKdoQak3Pk0cnXxCltA", "hash differs from RFC") + + def test_get_id_token_no_openid(self): + self.request.scopes = ('hello') + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + self.request.scopes = None + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + self.request.scopes = () + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertNotIn("id_token", token) + + def test_get_id_token(self): + self.mock_validator.get_id_token.return_value = "toto" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "toto") + + def test_finalize_id_token(self): + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['aud'], 'abcdef') + self.assertGreaterEqual(int(time.time()), id_token['iat']) + + def test_finalize_id_token_with_nonce(self): + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request, "my_nonce") + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['nonce'], 'my_nonce') + + def test_finalize_id_token_with_at_hash(self): + self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA') + + def test_finalize_id_token_with_c_hash(self): + self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA') + + def test_finalize_id_token_with_c_and_at_hash(self): + self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk" + token = self.grant.add_id_token(self.token, "token_handler_mock", self.request) + self.assertIn("id_token", token) + self.assertEqual(token["id_token"], "eyJ.body.signature") + id_token = self.mock_validator.finalize_id_token.call_args[0][0] + self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA') + self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA') diff --git a/tests/openid/connect/core/test_request_validator.py b/tests/openid/connect/core/test_request_validator.py index 1e71fb1..ebe0aeb 100644 --- a/tests/openid/connect/core/test_request_validator.py +++ b/tests/openid/connect/core/test_request_validator.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -from oauthlib.openid.connect.core.request_validator import RequestValidator +from oauthlib.openid import RequestValidator from tests.unittest import TestCase @@ -22,8 +22,8 @@ class RequestValidatorTest(TestCase): ) self.assertRaises( NotImplementedError, - v.get_id_token, - 'token', 'token_handler', 'request' + v.finalize_id_token, + 'id_token', 'token', 'token_handler', 'request' ) self.assertRaises( NotImplementedError, diff --git a/tests/openid/connect/core/test_tokens.py b/tests/openid/connect/core/test_tokens.py index 1fcfb51..fde89d6 100644 --- a/tests/openid/connect/core/test_tokens.py +++ b/tests/openid/connect/core/test_tokens.py @@ -42,7 +42,7 @@ class JWTTokenTestCase(TestCase): """ request_mock = mock.MagicMock() - with mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator', + with mock.patch('oauthlib.openid.RequestValidator', autospec=True) as RequestValidatorMock: request_validator = RequestValidatorMock() @@ -58,7 +58,7 @@ class JWTTokenTestCase(TestCase): """ with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \ - mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator', + mock.patch('oauthlib.openid.RequestValidator', autospec=True) as RequestValidatorMock: request_validator_mock = RequestValidatorMock() @@ -84,7 +84,7 @@ class JWTTokenTestCase(TestCase): """ with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \ - mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator', + mock.patch('oauthlib.openid.RequestValidator', autospec=True) as RequestValidatorMock: request_validator_mock = RequestValidatorMock() diff --git a/tests/test_common.py b/tests/test_common.py index 20d9f5b..ae2531b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import os import sys +import oauthlib from oauthlib.common import (CaseInsensitiveDict, Request, add_params_to_uri, extract_params, generate_client_id, generate_nonce, generate_timestamp, @@ -214,6 +216,20 @@ class RequestTest(TestCase): self.assertEqual(r.headers['token'], 'foobar') self.assertEqual(r.token, 'banana') + def test_sanitized_request_non_debug_mode(self): + """make sure requests are sanitized when in non debug mode. + For the debug mode, the other tests checking sanitization should prove + that debug mode is working. + """ + try: + oauthlib.set_debug(False) + r = Request(URI, headers={'token': 'foobar'}, body='token=banana') + self.assertNotIn('token', repr(r)) + self.assertIn('SANITIZED', repr(r)) + finally: + # set flag back for other tests + oauthlib.set_debug(True) + class CaseInsensitiveDictTest(TestCase): |