diff options
-rw-r--r-- | oauthlib/oauth2/draft25/__init__.py | 6 | ||||
-rw-r--r-- | tests/oauth2/draft25/test_client.py | 8 |
2 files changed, 12 insertions, 2 deletions
diff --git a/oauthlib/oauth2/draft25/__init__.py b/oauthlib/oauth2/draft25/__init__.py index bf916b5..98349e7 100644 --- a/oauthlib/oauth2/draft25/__init__.py +++ b/oauthlib/oauth2/draft25/__init__.py @@ -8,6 +8,7 @@ oauthlib.oauth2.draft_25 This module is an implementation of various logic needed for signing and checking OAuth 2.0 draft 25 requests. """ +import string import datetime import functools import logging @@ -121,7 +122,8 @@ class Client(object): token_placement = token_placement or self.default_token_placement - if not self.token_type in self.token_types: + case_insensitive_token_types = dict(zip(map(string.lower, self.token_types.keys()), self.token_types.values())) + if not self.token_type.lower() in case_insensitive_token_types: raise ValueError("Unsupported token type: %s" % self.token_type) if not self.access_token: @@ -130,7 +132,7 @@ class Client(object): if self._expires_at and self._expires_at < datetime.datetime.now(): raise TokenExpiredError() - return self.token_types[self.token_type](uri, http_method, body, + return case_insensitive_token_types[self.token_type.lower()](uri, http_method, body, headers, token_placement, **kwargs) def prepare_refresh_body(self, body='', refresh_token=None, scope=None, **kwargs): diff --git a/tests/oauth2/draft25/test_client.py b/tests/oauth2/draft25/test_client.py index 47ed538..acceaca 100644 --- a/tests/oauth2/draft25/test_client.py +++ b/tests/oauth2/draft25/test_client.py @@ -43,6 +43,14 @@ class ClientTest(TestCase): client = Client(self.client_id, token_type="invalid") self.assertRaises(ValueError, client.add_token, self.uri) + # Case-insensitive token type + client = Client(self.client_id, access_token=self.access_token, token_type="bEAreR") + uri, headers, body = client.add_token(self.uri, body=self.body, + headers=self.headers) + self.assertURLEqual(uri, self.uri) + self.assertFormBodyEqual(body, self.body) + self.assertEqual(headers, self.bearer_header) + # Missing access token client = Client(self.client_id) self.assertRaises(ValueError, client.add_token, self.uri) |