diff options
author | Abhishek Patel <5524161+Abhishek8394@users.noreply.github.com> | 2019-05-06 23:26:29 -0700 |
---|---|---|
committer | Abhishek Patel <5524161+Abhishek8394@users.noreply.github.com> | 2019-05-14 00:37:59 -0700 |
commit | 047ceccf48ea7ccd4ecc6b48a8ddb6dd4a14abd6 (patch) | |
tree | b4a8b62f205d5e41dc245273e34669319b1734f1 | |
parent | bbbcca731d5db16d7b1765070880aa54288788e9 (diff) | |
download | oauthlib-047ceccf48ea7ccd4ecc6b48a8ddb6dd4a14abd6.tar.gz |
Add tests + create a global variable for blacklisted query parameters
4 files changed, 68 insertions, 7 deletions
diff --git a/oauthlib/oauth2/rfc6749/endpoints/base.py b/oauthlib/oauth2/rfc6749/endpoints/base.py index 29086e4..dc3204b 100644 --- a/oauthlib/oauth2/rfc6749/endpoints/base.py +++ b/oauthlib/oauth2/rfc6749/endpoints/base.py @@ -15,17 +15,18 @@ from ..errors import (FatalClientError, OAuth2Error, ServerError, TemporarilyUnavailableError, InvalidRequestError, InvalidClientError, UnsupportedTokenTypeError) -from oauthlib.common import CaseInsensitiveDict +from oauthlib.common import CaseInsensitiveDict, urldecode log = logging.getLogger(__name__) +BLACKLIST_QUERY_PARAMS = {'client_secret', 'code_verifier'} class BaseEndpoint(object): def __init__(self): self._available = True self._catch_errors = False - self._blacklist_query_params = {'client_secret', 'code_verifier'} + self._blacklist_query_params = BLACKLIST_QUERY_PARAMS @property def available(self): @@ -33,7 +34,7 @@ class BaseEndpoint(object): @available.setter def available(self, available): - self._available = available + self._available = available @property def catch_errors(self): @@ -69,11 +70,12 @@ class BaseEndpoint(object): """Raise if invalid POST request received """ if request.http_method.lower() == 'post': - query_params = CaseInsensitiveDict(urldecode(request.uri_query)) - for k in self._blacklist_query_params: - if k in query_params: + query_params = CaseInsensitiveDict(dict(urldecode(request.uri_query))) + for param in self._blacklist_query_params: + if param in query_params: raise InvalidRequestError(request=request, - description='Query parameters not allowed') + description=('"%s" is not allowed as a url query' +\ + ' parameter') % (param)) def catch_errors_and_unavailability(f): @functools.wraps(f) diff --git a/tests/oauth2/rfc6749/endpoints/test_error_responses.py b/tests/oauth2/rfc6749/endpoints/test_error_responses.py index a249cb1..4a288ad 100644 --- a/tests/oauth2/rfc6749/endpoints/test_error_responses.py +++ b/tests/oauth2/rfc6749/endpoints/test_error_responses.py @@ -6,10 +6,12 @@ import json import mock +from oauthlib.common import urlencode from oauthlib.oauth2 import (BackendApplicationServer, LegacyApplicationServer, MobileApplicationServer, RequestValidator, WebApplicationServer) from oauthlib.oauth2.rfc6749 import errors +from oauthlib.oauth2.rfc6749.endpoints.base import BLACKLIST_QUERY_PARAMS from ....unittest import TestCase @@ -437,3 +439,28 @@ class ErrorResponseTest(TestCase): _, body, _ = self.backend.create_token_response('https://i.b/token', body='grant_type=bar') self.assertEqual('unsupported_grant_type', json.loads(body)['error']) + + def test_invalid_post_request(self): + self.validator.authenticate_client.side_effect = self.set_client + for param in BLACKLIST_QUERY_PARAMS: + uri = 'https://i/b/token?' + urlencode([(param, 'secret')]) + _, body, s = self.web.create_introspect_response(uri, + body='grant_type=access_token&code=123') + self.assertEqual(json.loads(body)['error'], 'invalid_request') + self.assertIn(param, json.loads(body)['error_description']) + self.assertIn('not allowed', json.loads(body)['error_description']) + self.assertEqual(s, 400) + + _, body, s = self.legacy.create_introspect_response(uri, + body='grant_type=access_token&code=123') + self.assertEqual(json.loads(body)['error'], 'invalid_request') + self.assertIn(param, json.loads(body)['error_description']) + self.assertIn('not allowed', json.loads(body)['error_description']) + self.assertEqual(s, 400) + + _, body, s = self.backend.create_introspect_response(uri, + body='grant_type=access_token&code=123') + self.assertEqual(json.loads(body)['error'], 'invalid_request') + self.assertIn(param, json.loads(body)['error_description']) + self.assertIn('not allowed', json.loads(body)['error_description']) + self.assertEqual(s, 400) diff --git a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py index b9bf76a..234a4ef 100644 --- a/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py +++ b/tests/oauth2/rfc6749/endpoints/test_introspect_endpoint.py @@ -7,6 +7,7 @@ from mock import MagicMock from oauthlib.common import urlencode from oauthlib.oauth2 import RequestValidator, IntrospectEndpoint +from oauthlib.oauth2.rfc6749.endpoints.base import BLACKLIST_QUERY_PARAMS from ....unittest import TestCase @@ -139,3 +140,18 @@ class IntrospectEndpointTest(TestCase): self.assertEqual(h, self.resp_h) self.assertEqual(loads(b)['error'], 'invalid_request') self.assertEqual(s, 400) + + def test_introspect_bad_post_request(self): + endpoint = IntrospectEndpoint(self.validator, + supported_token_types=['access_token']) + for param in BLACKLIST_QUERY_PARAMS: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_introspect_response(uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn(param, loads(b)['error_description']) + self.assertIn('not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) diff --git a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py index 2a24177..e89c3bd 100644 --- a/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py +++ b/tests/oauth2/rfc6749/endpoints/test_revocation_endpoint.py @@ -7,6 +7,7 @@ from mock import MagicMock from oauthlib.common import urlencode from oauthlib.oauth2 import RequestValidator, RevocationEndpoint +from oauthlib.oauth2.rfc6749.endpoints.base import BLACKLIST_QUERY_PARAMS from ....unittest import TestCase @@ -120,3 +121,18 @@ class RevocationEndpointTest(TestCase): self.assertEqual(h, self.resp_h) self.assertEqual(loads(b)['error'], 'invalid_request') self.assertEqual(s, 400) + + def test_revoke_bad_post_request(self): + endpoint = RevocationEndpoint(self.validator, + supported_token_types=['access_token']) + for param in BLACKLIST_QUERY_PARAMS: + uri = 'http://some.endpoint?' + urlencode([(param, 'secret')]) + body = urlencode([('token', 'foo'), + ('token_type_hint', 'access_token')]) + h, b, s = endpoint.create_revocation_response(uri, + headers=self.headers, body=body) + self.assertEqual(h, self.resp_h) + self.assertEqual(loads(b)['error'], 'invalid_request') + self.assertIn(param, loads(b)['error_description']) + self.assertIn('not allowed', loads(b)['error_description']) + self.assertEqual(s, 400) |