summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py3
-rw-r--r--tests/oauth1/rfc5849/endpoints/test_base.py13
-rw-r--r--tests/oauth1/rfc5849/test_signatures.py39
-rw-r--r--tests/oauth2/rfc6749/endpoints/test_base_endpoint.py2
-rw-r--r--tests/oauth2/rfc6749/endpoints/test_error_responses.py55
-rw-r--r--tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py30
-rw-r--r--tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py29
-rw-r--r--tests/oauth2/rfc6749/grant_types/test_authorization_code.py6
-rw-r--r--tests/oauth2/rfc6749/test_parameters.py27
-rw-r--r--tests/oauth2/rfc6749/test_tokens.py79
-rw-r--r--tests/openid/connect/core/endpoints/test_claims_handling.py2
-rw-r--r--tests/openid/connect/core/grant_types/test_authorization_code.py1
-rw-r--r--tests/openid/connect/core/grant_types/test_base.py104
-rw-r--r--tests/openid/connect/core/test_request_validator.py6
-rw-r--r--tests/openid/connect/core/test_tokens.py6
-rw-r--r--tests/test_common.py16
16 files changed, 386 insertions, 32 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..f33236b 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,3 @@
+import oauthlib
+
+oauthlib.set_debug(True)
diff --git a/tests/oauth1/rfc5849/endpoints/test_base.py b/tests/oauth1/rfc5849/endpoints/test_base.py
index 60f7860..795ddee 100644
--- a/tests/oauth1/rfc5849/endpoints/test_base.py
+++ b/tests/oauth1/rfc5849/endpoints/test_base.py
@@ -4,7 +4,7 @@ from re import sub
from mock import MagicMock
-from oauthlib.common import safe_string_equals
+from oauthlib.common import CaseInsensitiveDict, safe_string_equals
from oauthlib.oauth1 import Client, RequestValidator
from oauthlib.oauth1.rfc5849 import (SIGNATURE_HMAC, SIGNATURE_PLAINTEXT,
SIGNATURE_RSA, errors)
@@ -179,6 +179,17 @@ class BaseEndpointTest(TestCase):
self.assertRaises(errors.InvalidRequestError,
e._check_mandatory_parameters, r)
+ def test_case_insensitive_headers(self):
+ """Ensure headers are case-insensitive"""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request('https://a.b', 'POST',
+ ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
+ 'oauth_timestamp=123456789a'),
+ URLENCODED)
+ self.assertIsInstance(r.headers, CaseInsensitiveDict)
+
def test_signature_method_validation(self):
"""Ensure valid signature method is used."""
diff --git a/tests/oauth1/rfc5849/test_signatures.py b/tests/oauth1/rfc5849/test_signatures.py
index 48609e5..bb0dc78 100644
--- a/tests/oauth1/rfc5849/test_signatures.py
+++ b/tests/oauth1/rfc5849/test_signatures.py
@@ -3,8 +3,8 @@ from __future__ import absolute_import, unicode_literals
from oauthlib.common import unicode_type
from oauthlib.oauth1.rfc5849.signature import (collect_parameters,
- construct_base_string,
- normalize_base_string_uri,
+ signature_base_string,
+ base_string_uri,
normalize_parameters,
sign_hmac_sha1,
sign_hmac_sha1_with_client,
@@ -79,7 +79,7 @@ class SignatureTests(TestCase):
resource_owner_secret = self.resource_owner_secret
)
- def test_construct_base_string(self):
+ def test_signature_base_string(self):
"""
Example text to be turned into a base string::
@@ -104,20 +104,20 @@ class SignatureTests(TestCase):
D%2522137131201%2522%252Coauth_nonce%253D%25227d8f3e4a%2522%252Coau
th_signature%253D%2522bYT5CMsGcbgUdFHObYMEfcx6bsw%25253D%2522
"""
- self.assertRaises(ValueError, construct_base_string,
+ self.assertRaises(ValueError, signature_base_string,
self.http_method,
self.base_string_url,
self.normalized_encoded_request_parameters)
- self.assertRaises(ValueError, construct_base_string,
+ self.assertRaises(ValueError, signature_base_string,
self.http_method.decode('utf-8'),
self.base_string_url,
self.normalized_encoded_request_parameters)
- self.assertRaises(ValueError, construct_base_string,
+ self.assertRaises(ValueError, signature_base_string,
self.http_method.decode('utf-8'),
self.base_string_url.decode('utf-8'),
self.normalized_encoded_request_parameters)
- base_string = construct_base_string(
+ base_string = signature_base_string(
self.http_method.decode('utf-8'),
self.base_string_url.decode('utf-8'),
self.normalized_encoded_request_parameters.decode('utf-8')
@@ -125,7 +125,7 @@ class SignatureTests(TestCase):
self.assertEqual(self.control_base_string, base_string)
- def test_normalize_base_string_uri(self):
+ def test_base_string_uri(self):
"""
Example text to be turned into a normalized base string uri::
@@ -137,33 +137,44 @@ class SignatureTests(TestCase):
https://www.example.net:8080/
"""
+ # test first example from RFC 5849 section 3.4.1.2.
+ # Note: there is a space between "r" and "v"
+ uri = 'http://EXAMPLE.COM:80/r v/X?id=123'
+ self.assertEqual(base_string_uri(uri),
+ 'http://example.com/r%20v/X')
+
+ # test second example from RFC 5849 section 3.4.1.2.
+ uri = 'https://www.example.net:8080/?q=1'
+ self.assertEqual(base_string_uri(uri),
+ 'https://www.example.net:8080/')
+
# test for unicode failure
uri = b"www.example.com:8080"
- self.assertRaises(ValueError, normalize_base_string_uri, uri)
+ self.assertRaises(ValueError, base_string_uri, uri)
# test for missing scheme
uri = "www.example.com:8080"
- self.assertRaises(ValueError, normalize_base_string_uri, uri)
+ self.assertRaises(ValueError, base_string_uri, uri)
# test a URI with the default port
uri = "http://www.example.com:80/"
- self.assertEqual(normalize_base_string_uri(uri),
+ self.assertEqual(base_string_uri(uri),
"http://www.example.com/")
# test a URI missing a path
uri = "http://www.example.com"
- self.assertEqual(normalize_base_string_uri(uri),
+ self.assertEqual(base_string_uri(uri),
"http://www.example.com/")
# test a relative URI
uri = "/a-host-relative-uri"
host = "www.example.com"
- self.assertRaises(ValueError, normalize_base_string_uri, (uri, host))
+ self.assertRaises(ValueError, base_string_uri, (uri, host))
# test overriding the URI's netloc with a host argument
uri = "http://www.example.com/a-path"
host = "alternatehost.example.com"
- self.assertEqual(normalize_base_string_uri(uri, host),
+ self.assertEqual(base_string_uri(uri, host),
"http://alternatehost.example.com/a-path")
def test_collect_parameters(self):
diff --git a/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py
index 4f78d9b..bf04a42 100644
--- a/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py
+++ b/tests/oauth2/rfc6749/endpoints/test_base_endpoint.py
@@ -25,7 +25,7 @@ class BaseEndpointTest(TestCase):
server = Server(validator)
server.catch_errors = True
h, b, s = server.create_token_response(
- 'https://example.com?grant_type=authorization_code&code=abc'
+ 'https://example.com', body='grant_type=authorization_code&code=abc'
)
self.assertIn("server_error", b)
self.assertEqual(s, 500)
diff --git a/tests/oauth2/rfc6749/endpoints/test_error_responses.py b/tests/oauth2/rfc6749/endpoints/test_error_responses.py
index a249cb1..2479836 100644
--- a/tests/oauth2/rfc6749/endpoints/test_error_responses.py
+++ b/tests/oauth2/rfc6749/endpoints/test_error_responses.py
@@ -6,11 +6,11 @@ import json
import mock
+from oauthlib.common import urlencode
from oauthlib.oauth2 import (BackendApplicationServer, LegacyApplicationServer,
MobileApplicationServer, RequestValidator,
WebApplicationServer)
from oauthlib.oauth2.rfc6749 import errors
-
from ....unittest import TestCase
@@ -437,3 +437,56 @@ class ErrorResponseTest(TestCase):
_, body, _ = self.backend.create_token_response('https://i.b/token',
body='grant_type=bar')
self.assertEqual('unsupported_grant_type', json.loads(body)['error'])
+
+ def test_invalid_request_method(self):
+ test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH']
+ test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods]
+ for method in test_methods:
+ self.validator.authenticate_client.side_effect = self.set_client
+
+ uri = "http://i/b/token/"
+ try:
+ _, body, s = self.web.create_token_response(uri,
+ body='grant_type=access_token&code=123', http_method=method)
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('Unsupported request method', ire.description)
+
+ try:
+ _, body, s = self.legacy.create_token_response(uri,
+ body='grant_type=access_token&code=123', http_method=method)
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('Unsupported request method', ire.description)
+
+ try:
+ _, body, s = self.backend.create_token_response(uri,
+ body='grant_type=access_token&code=123', http_method=method)
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('Unsupported request method', ire.description)
+
+ def test_invalid_post_request(self):
+ self.validator.authenticate_client.side_effect = self.set_client
+ for param in ['token', 'secret', 'code', 'foo']:
+ uri = 'https://i/b/token?' + urlencode([(param, 'secret')])
+ try:
+ _, body, s = self.web.create_token_response(uri,
+ body='grant_type=access_token&code=123')
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('URL query parameters are not allowed', ire.description)
+
+ try:
+ _, body, s = self.legacy.create_token_response(uri,
+ body='grant_type=access_token&code=123')
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('URL query parameters are not allowed', ire.description)
+
+ try:
+ _, body, s = self.backend.create_token_response(uri,
+ body='grant_type=access_token&code=123')
+ self.fail('This should have failed with InvalidRequestError')
+ except errors.InvalidRequestError as ire:
+ self.assertIn('URL query parameters are not allowed', ire.description)
diff --git a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py
index b9bf76a..ae3deae 100644
--- a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py
+++ b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py
@@ -139,3 +139,33 @@ class IntrospectEndpointTest(TestCase):
self.assertEqual(h, self.resp_h)
self.assertEqual(loads(b)['error'], 'invalid_request')
self.assertEqual(s, 400)
+
+ def test_introspect_invalid_request_method(self):
+ endpoint = IntrospectEndpoint(self.validator,
+ supported_token_types=['access_token'])
+ test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH']
+ test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods]
+ for method in test_methods:
+ body = urlencode([('token', 'foo'),
+ ('token_type_hint', 'refresh_token')])
+ h, b, s = endpoint.create_introspect_response(self.uri,
+ http_method = method, headers=self.headers, body=body)
+ self.assertEqual(h, self.resp_h)
+ self.assertEqual(loads(b)['error'], 'invalid_request')
+ self.assertIn('Unsupported request method', loads(b)['error_description'])
+ self.assertEqual(s, 400)
+
+ def test_introspect_bad_post_request(self):
+ endpoint = IntrospectEndpoint(self.validator,
+ supported_token_types=['access_token'])
+ for param in ['token', 'secret', 'code', 'foo']:
+ uri = 'http://some.endpoint?' + urlencode([(param, 'secret')])
+ body = urlencode([('token', 'foo'),
+ ('token_type_hint', 'access_token')])
+ h, b, s = endpoint.create_introspect_response(
+ uri,
+ headers=self.headers, body=body)
+ self.assertEqual(h, self.resp_h)
+ self.assertEqual(loads(b)['error'], 'invalid_request')
+ self.assertIn('query parameters are not allowed', loads(b)['error_description'])
+ self.assertEqual(s, 400)
diff --git a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py
index 2a24177..17be3a5 100644
--- a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py
+++ b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py
@@ -120,3 +120,32 @@ class RevocationEndpointTest(TestCase):
self.assertEqual(h, self.resp_h)
self.assertEqual(loads(b)['error'], 'invalid_request')
self.assertEqual(s, 400)
+
+ def test_revoke_invalid_request_method(self):
+ endpoint = RevocationEndpoint(self.validator,
+ supported_token_types=['access_token'])
+ test_methods = ['GET', 'pUt', 'dEleTe', 'paTcH']
+ test_methods = test_methods + [x.lower() for x in test_methods] + [x.upper() for x in test_methods]
+ for method in test_methods:
+ body = urlencode([('token', 'foo'),
+ ('token_type_hint', 'refresh_token')])
+ h, b, s = endpoint.create_revocation_response(self.uri,
+ http_method = method, headers=self.headers, body=body)
+ self.assertEqual(h, self.resp_h)
+ self.assertEqual(loads(b)['error'], 'invalid_request')
+ self.assertIn('Unsupported request method', loads(b)['error_description'])
+ self.assertEqual(s, 400)
+
+ def test_revoke_bad_post_request(self):
+ endpoint = RevocationEndpoint(self.validator,
+ supported_token_types=['access_token'])
+ for param in ['token', 'secret', 'code', 'foo']:
+ uri = 'http://some.endpoint?' + urlencode([(param, 'secret')])
+ body = urlencode([('token', 'foo'),
+ ('token_type_hint', 'access_token')])
+ h, b, s = endpoint.create_revocation_response(uri,
+ headers=self.headers, body=body)
+ self.assertEqual(h, self.resp_h)
+ self.assertEqual(loads(b)['error'], 'invalid_request')
+ self.assertIn('query parameters are not allowed', loads(b)['error_description'])
+ self.assertEqual(s, 400)
diff --git a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
index 00e2b6d..2c9db3c 100644
--- a/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
+++ b/tests/oauth2/rfc6749/grant_types/test_authorization_code.py
@@ -215,8 +215,10 @@ class AuthorizationCodeGrantTest(TestCase):
self.mock_validator.is_pkce_required.return_value = required
self.request.code_challenge = "present"
_, ri = self.auth.validate_authorization_request(self.request)
- self.assertIsNotNone(ri["request"].code_challenge_method)
- self.assertEqual(ri["request"].code_challenge_method, "plain")
+ self.assertIn("code_challenge", ri)
+ self.assertIn("code_challenge_method", ri)
+ self.assertEqual(ri["code_challenge"], "present")
+ self.assertEqual(ri["code_challenge_method"], "plain")
def test_pkce_wrong_method(self):
for required in [True, False]:
diff --git a/tests/oauth2/rfc6749/test_parameters.py b/tests/oauth2/rfc6749/test_parameters.py
index c42f516..48b7eac 100644
--- a/tests/oauth2/rfc6749/test_parameters.py
+++ b/tests/oauth2/rfc6749/test_parameters.py
@@ -73,7 +73,8 @@ class ParameterTests(TestCase):
error_nocode = 'https://client.example.com/cb?state=xyz'
error_nostate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA'
error_wrongstate = 'https://client.example.com/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=abc'
- error_response = 'https://client.example.com/cb?error=access_denied&state=xyz'
+ error_denied = 'https://client.example.com/cb?error=access_denied&state=xyz'
+ error_invalid = 'https://client.example.com/cb?error=invalid_request&state=xyz'
implicit_base = 'https://example.com/cb#access_token=2YotnFZFEjr1zCsicMWpAA&scope=abc&'
implicit_response = implicit_base + 'state={0}&token_type=example&expires_in=3600'.format(state)
@@ -102,6 +103,15 @@ class ParameterTests(TestCase):
' "expires_in": 3600,'
' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",'
' "example_parameter": "example_value" }')
+ json_response_noexpire = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",'
+ ' "token_type": "example",'
+ ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",'
+ ' "example_parameter": "example_value"}')
+ json_response_expirenull = ('{ "access_token": "2YotnFZFEjr1zCsicMWpAA",'
+ ' "token_type": "example",'
+ ' "expires_in": null,'
+ ' "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA",'
+ ' "example_parameter": "example_value"}')
json_custom_error = '{ "error": "incorrect_client_credentials" }'
json_error = '{ "error": "access_denied" }'
@@ -135,6 +145,13 @@ class ParameterTests(TestCase):
'example_parameter': 'example_value'
}
+ json_noexpire_dict = {
+ 'access_token': '2YotnFZFEjr1zCsicMWpAA',
+ 'token_type': 'example',
+ 'refresh_token': 'tGzv3JOkF0XG5Qx2TlKWIA',
+ 'example_parameter': 'example_value'
+ }
+
json_notype_dict = {
'access_token': '2YotnFZFEjr1zCsicMWpAA',
'expires_in': 3600,
@@ -180,8 +197,10 @@ class ParameterTests(TestCase):
self.assertRaises(MissingCodeError, parse_authorization_code_response,
self.error_nocode)
- self.assertRaises(MissingCodeError, parse_authorization_code_response,
- self.error_response)
+ self.assertRaises(AccessDeniedError, parse_authorization_code_response,
+ self.error_denied)
+ self.assertRaises(InvalidRequestFatalError, parse_authorization_code_response,
+ self.error_invalid)
self.assertRaises(MismatchingStateError, parse_authorization_code_response,
self.error_nostate, state=self.state)
self.assertRaises(MismatchingStateError, parse_authorization_code_response,
@@ -209,6 +228,8 @@ class ParameterTests(TestCase):
self.assertEqual(parse_token_response(self.json_response_noscope,
scope=['all', 'the', 'scopes']), self.json_noscope_dict)
+ self.assertEqual(parse_token_response(self.json_response_noexpire), self.json_noexpire_dict)
+ self.assertEqual(parse_token_response(self.json_response_expirenull), self.json_noexpire_dict)
scope_changes_recorded = []
def record_scope_change(sender, message, old, new):
diff --git a/tests/oauth2/rfc6749/test_tokens.py b/tests/oauth2/rfc6749/test_tokens.py
index 061754f..e6f49b1 100644
--- a/tests/oauth2/rfc6749/test_tokens.py
+++ b/tests/oauth2/rfc6749/test_tokens.py
@@ -1,10 +1,14 @@
from __future__ import absolute_import, unicode_literals
+import mock
+
+from oauthlib.common import Request
from oauthlib.oauth2.rfc6749.tokens import (
- prepare_mac_header,
- prepare_bearer_headers,
+ BearerToken,
prepare_bearer_body,
+ prepare_bearer_headers,
prepare_bearer_uri,
+ prepare_mac_header,
)
from ...unittest import TestCase
@@ -64,6 +68,7 @@ class TokenTest(TestCase):
bearer_headers = {
'Authorization': 'Bearer vF9dft4qmT'
}
+ valid_bearer_header_lowercase = {"Authorization": "bearer vF9dft4qmT"}
fake_bearer_headers = [
{'Authorization': 'Beaver vF9dft4qmT'},
{'Authorization': 'BeavervF9dft4qmT'},
@@ -98,3 +103,73 @@ class TokenTest(TestCase):
self.assertEqual(prepare_bearer_headers(self.token), self.bearer_headers)
self.assertEqual(prepare_bearer_body(self.token), self.bearer_body)
self.assertEqual(prepare_bearer_uri(self.token, uri=self.uri), self.bearer_uri)
+
+ def test_valid_bearer_is_validated(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+
+ request = Request("/", headers=self.bearer_headers)
+ result = BearerToken(request_validator=request_validator).validate_request(
+ request
+ )
+ self.assertTrue(result)
+
+ def test_lowercase_bearer_is_validated(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+
+ request = Request("/", headers=self.valid_bearer_header_lowercase)
+ result = BearerToken(request_validator=request_validator).validate_request(
+ request
+ )
+ self.assertTrue(result)
+
+ def test_fake_bearer_is_not_validated(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+
+ for fake_header in self.fake_bearer_headers:
+ request = Request("/", headers=fake_header)
+ result = BearerToken(request_validator=request_validator).validate_request(
+ request
+ )
+
+ self.assertFalse(result)
+
+ def test_header_with_multispaces_is_validated(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+
+ request = Request("/", headers=self.valid_header_with_multiple_spaces)
+ result = BearerToken(request_validator=request_validator).validate_request(
+ request
+ )
+
+ self.assertTrue(result)
+
+ def test_estimate_type(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+ request = Request("/", headers=self.bearer_headers)
+ result = BearerToken(request_validator=request_validator).estimate_type(request)
+ self.assertEqual(result, 9)
+
+ def test_estimate_type_with_fake_header_returns_type_0(self):
+ request_validator = mock.MagicMock()
+ request_validator.validate_bearer_token = self._mocked_validate_bearer_token
+
+ for fake_header in self.fake_bearer_headers:
+ request = Request("/", headers=fake_header)
+ result = BearerToken(request_validator=request_validator).estimate_type(
+ request
+ )
+
+ if (
+ fake_header["Authorization"].count(" ") == 2
+ and fake_header["Authorization"].split()[0] == "Bearer"
+ ):
+ # If we're dealing with the header containing 2 spaces, it will be recognized
+ # as a Bearer valid header, the token itself will be invalid by the way.
+ self.assertEqual(result, 9)
+ else:
+ self.assertEqual(result, 0)
diff --git a/tests/openid/connect/core/endpoints/test_claims_handling.py b/tests/openid/connect/core/endpoints/test_claims_handling.py
index 270ef69..5f39d96 100644
--- a/tests/openid/connect/core/endpoints/test_claims_handling.py
+++ b/tests/openid/connect/core/endpoints/test_claims_handling.py
@@ -10,7 +10,7 @@ from __future__ import absolute_import, unicode_literals
import mock
-from oauthlib.oauth2 import RequestValidator
+from oauthlib.openid import RequestValidator
from oauthlib.openid.connect.core.endpoints.pre_configured import Server
from tests.unittest import TestCase
diff --git a/tests/openid/connect/core/grant_types/test_authorization_code.py b/tests/openid/connect/core/grant_types/test_authorization_code.py
index 89401ab..b721a19 100644
--- a/tests/openid/connect/core/grant_types/test_authorization_code.py
+++ b/tests/openid/connect/core/grant_types/test_authorization_code.py
@@ -117,7 +117,6 @@ class OpenIDAuthCodeTest(TestCase):
def set_scopes(self, client_id, code, client, request):
request.scopes = self.request.scopes
- request.state = self.request.state
request.user = 'bob'
return True
diff --git a/tests/openid/connect/core/grant_types/test_base.py b/tests/openid/connect/core/grant_types/test_base.py
new file mode 100644
index 0000000..d506b7e
--- /dev/null
+++ b/tests/openid/connect/core/grant_types/test_base.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+import mock
+import time
+
+from oauthlib.common import Request
+from oauthlib.openid.connect.core.grant_types.base import GrantTypeBase
+
+from tests.unittest import TestCase
+
+
+class GrantBase(GrantTypeBase):
+ """Class to test GrantTypeBase"""
+ def __init__(self, request_validator=None, **kwargs):
+ self.request_validator = request_validator
+
+
+class IDTokenTest(TestCase):
+
+ def setUp(self):
+ self.request = Request('http://a.b/path')
+ self.request.scopes = ('hello', 'openid')
+ self.request.expires_in = 1800
+ self.request.client_id = 'abcdef'
+ self.request.code = '1234'
+ self.request.response_type = 'id_token'
+ self.request.grant_type = 'authorization_code'
+ self.request.redirect_uri = 'https://a.b/cb'
+ self.request.state = 'abc'
+ self.request.nonce = None
+
+ self.mock_validator = mock.MagicMock()
+ self.mock_validator.get_id_token.return_value = None
+ self.mock_validator.finalize_id_token.return_value = "eyJ.body.signature"
+ self.token = {}
+
+ self.grant = GrantBase(request_validator=self.mock_validator)
+
+ self.url_query = 'https://a.b/cb?code=abc&state=abc'
+ self.url_fragment = 'https://a.b/cb#code=abc&state=abc'
+
+ def test_id_token_hash(self):
+ self.assertEqual(self.grant.id_token_hash(
+ "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk",
+ ), "LDktKdoQak3Pk0cnXxCltA", "hash differs from RFC")
+
+ def test_get_id_token_no_openid(self):
+ self.request.scopes = ('hello')
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertNotIn("id_token", token)
+
+ self.request.scopes = None
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertNotIn("id_token", token)
+
+ self.request.scopes = ()
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertNotIn("id_token", token)
+
+ def test_get_id_token(self):
+ self.mock_validator.get_id_token.return_value = "toto"
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "toto")
+
+ def test_finalize_id_token(self):
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "eyJ.body.signature")
+ id_token = self.mock_validator.finalize_id_token.call_args[0][0]
+ self.assertEqual(id_token['aud'], 'abcdef')
+ self.assertGreaterEqual(int(time.time()), id_token['iat'])
+
+ def test_finalize_id_token_with_nonce(self):
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request, "my_nonce")
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "eyJ.body.signature")
+ id_token = self.mock_validator.finalize_id_token.call_args[0][0]
+ self.assertEqual(id_token['nonce'], 'my_nonce')
+
+ def test_finalize_id_token_with_at_hash(self):
+ self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk"
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "eyJ.body.signature")
+ id_token = self.mock_validator.finalize_id_token.call_args[0][0]
+ self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA')
+
+ def test_finalize_id_token_with_c_hash(self):
+ self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk"
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "eyJ.body.signature")
+ id_token = self.mock_validator.finalize_id_token.call_args[0][0]
+ self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA')
+
+ def test_finalize_id_token_with_c_and_at_hash(self):
+ self.token["code"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk"
+ self.token["access_token"] = "Qcb0Orv1zh30vL1MPRsbm-diHiMwcLyZvn1arpZv-Jxf_11jnpEX3Tgfvk"
+ token = self.grant.add_id_token(self.token, "token_handler_mock", self.request)
+ self.assertIn("id_token", token)
+ self.assertEqual(token["id_token"], "eyJ.body.signature")
+ id_token = self.mock_validator.finalize_id_token.call_args[0][0]
+ self.assertEqual(id_token['at_hash'], 'LDktKdoQak3Pk0cnXxCltA')
+ self.assertEqual(id_token['c_hash'], 'LDktKdoQak3Pk0cnXxCltA')
diff --git a/tests/openid/connect/core/test_request_validator.py b/tests/openid/connect/core/test_request_validator.py
index 1e71fb1..ebe0aeb 100644
--- a/tests/openid/connect/core/test_request_validator.py
+++ b/tests/openid/connect/core/test_request_validator.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-from oauthlib.openid.connect.core.request_validator import RequestValidator
+from oauthlib.openid import RequestValidator
from tests.unittest import TestCase
@@ -22,8 +22,8 @@ class RequestValidatorTest(TestCase):
)
self.assertRaises(
NotImplementedError,
- v.get_id_token,
- 'token', 'token_handler', 'request'
+ v.finalize_id_token,
+ 'id_token', 'token', 'token_handler', 'request'
)
self.assertRaises(
NotImplementedError,
diff --git a/tests/openid/connect/core/test_tokens.py b/tests/openid/connect/core/test_tokens.py
index 1fcfb51..fde89d6 100644
--- a/tests/openid/connect/core/test_tokens.py
+++ b/tests/openid/connect/core/test_tokens.py
@@ -42,7 +42,7 @@ class JWTTokenTestCase(TestCase):
"""
request_mock = mock.MagicMock()
- with mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator',
+ with mock.patch('oauthlib.openid.RequestValidator',
autospec=True) as RequestValidatorMock:
request_validator = RequestValidatorMock()
@@ -58,7 +58,7 @@ class JWTTokenTestCase(TestCase):
"""
with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \
- mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator',
+ mock.patch('oauthlib.openid.RequestValidator',
autospec=True) as RequestValidatorMock:
request_validator_mock = RequestValidatorMock()
@@ -84,7 +84,7 @@ class JWTTokenTestCase(TestCase):
"""
with mock.patch('oauthlib.common.Request', autospec=True) as RequestMock, \
- mock.patch('oauthlib.oauth2.rfc6749.request_validator.RequestValidator',
+ mock.patch('oauthlib.openid.RequestValidator',
autospec=True) as RequestValidatorMock:
request_validator_mock = RequestValidatorMock()
diff --git a/tests/test_common.py b/tests/test_common.py
index 20d9f5b..ae2531b 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import os
import sys
+import oauthlib
from oauthlib.common import (CaseInsensitiveDict, Request, add_params_to_uri,
extract_params, generate_client_id,
generate_nonce, generate_timestamp,
@@ -214,6 +216,20 @@ class RequestTest(TestCase):
self.assertEqual(r.headers['token'], 'foobar')
self.assertEqual(r.token, 'banana')
+ def test_sanitized_request_non_debug_mode(self):
+ """make sure requests are sanitized when in non debug mode.
+ For the debug mode, the other tests checking sanitization should prove
+ that debug mode is working.
+ """
+ try:
+ oauthlib.set_debug(False)
+ r = Request(URI, headers={'token': 'foobar'}, body='token=banana')
+ self.assertNotIn('token', repr(r))
+ self.assertIn('SANITIZED', repr(r))
+ finally:
+ # set flag back for other tests
+ oauthlib.set_debug(True)
+
class CaseInsensitiveDictTest(TestCase):