diff options
Diffstat (limited to 'oauthlib/oauth2/rfc6749/endpoints/base.py')
-rw-r--r-- | oauthlib/oauth2/rfc6749/endpoints/base.py | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/oauthlib/oauth2/rfc6749/endpoints/base.py b/oauthlib/oauth2/rfc6749/endpoints/base.py index c0fc726..e39232f 100644 --- a/oauthlib/oauth2/rfc6749/endpoints/base.py +++ b/oauthlib/oauth2/rfc6749/endpoints/base.py @@ -15,6 +15,8 @@ from ..errors import (FatalClientError, OAuth2Error, ServerError, TemporarilyUnavailableError, InvalidRequestError, InvalidClientError, UnsupportedTokenTypeError) +from oauthlib.common import CaseInsensitiveDict, urldecode + log = logging.getLogger(__name__) @@ -23,6 +25,18 @@ class BaseEndpoint(object): def __init__(self): self._available = True self._catch_errors = False + self._valid_request_methods = None + + @property + def valid_request_methods(self): + return self._valid_request_methods + + @valid_request_methods.setter + def valid_request_methods(self, valid_request_methods): + if valid_request_methods is not None: + valid_request_methods = [x.upper() for x in valid_request_methods] + self._valid_request_methods = valid_request_methods + @property def available(self): @@ -30,7 +44,7 @@ class BaseEndpoint(object): @available.setter def available(self, available): - self._available = available + self._available = available @property def catch_errors(self): @@ -62,6 +76,21 @@ class BaseEndpoint(object): request.token_type_hint not in self.supported_token_types): raise UnsupportedTokenTypeError(request=request) + def _raise_on_bad_method(self, request): + if self.valid_request_methods is None: + raise ValueError('Configure "valid_request_methods" property first') + if request.http_method.upper() not in self.valid_request_methods: + raise InvalidRequestError(request=request, + description=('Unsupported request method %s' % request.http_method.upper())) + + def _raise_on_bad_post_request(self, request): + """Raise if invalid POST request received + """ + if request.http_method.upper() == 'POST': + query_params = request.uri_query or "" + if query_params: + raise InvalidRequestError(request=request, + description=('URL query parameters are not allowed')) def catch_errors_and_unavailability(f): @functools.wraps(f) |