summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIb Lundgren <ib.lundgren@gmail.com>2012-08-11 23:15:02 +0200
committerIb Lundgren <ib.lundgren@gmail.com>2012-08-11 23:15:02 +0200
commitb3a769e8a85bb3fbd3d39d6facf4386c768730af (patch)
tree4c436ecf2e9ac2980159406bce4b6a10d0f59a9c
parent295a8b2c41f16f38def61f806280d2386c5bffa7 (diff)
downloadoauthlib-b3a769e8a85bb3fbd3d39d6facf4386c768730af.tar.gz
and again
-rw-r--r--oauthlib/common.py9
-rw-r--r--oauthlib/oauth2/draft25/__init__.py126
-rw-r--r--tests/oauth2/draft25/test_server.py27
3 files changed, 102 insertions, 60 deletions
diff --git a/oauthlib/common.py b/oauthlib/common.py
index 109e4bb..c2ef553 100644
--- a/oauthlib/common.py
+++ b/oauthlib/common.py
@@ -177,12 +177,16 @@ def add_params_to_qs(query, params):
return urlencode(queryparams)
-def add_params_to_uri(uri, params):
+def add_params_to_uri(uri, params, fragment=False):
"""Add a list of two-tuples to the uri query components."""
sch, net, path, par, query, fra = urlparse.urlparse(uri)
- query = add_params_to_qs(query, params)
+ if fragment:
+ fra = add_params_to_qs(query, params)
+ else:
+ query = add_params_to_qs(query, params)
return urlparse.urlunparse((sch, net, path, par, query, fra))
+
def safe_string_equals(a, b):
""" Near-constant time string comparison.
@@ -200,6 +204,7 @@ def safe_string_equals(a, b):
result |= ord(x) ^ ord(y)
return result == 0
+
class Request(object):
"""A malleable representation of a signable HTTP request.
diff --git a/oauthlib/oauth2/draft25/__init__.py b/oauthlib/oauth2/draft25/__init__.py
index 927c166..779efec 100644
--- a/oauthlib/oauth2/draft25/__init__.py
+++ b/oauthlib/oauth2/draft25/__init__.py
@@ -564,9 +564,6 @@ class OAuth2Error(Exception):
def json(self):
pass
- def add_to_uri(self, uri):
- return add_params_to_uri(uri, self.twotuples)
-
class AuthorizationEndpoint(object):
"""Authorization endpoint - used by the client to obtain authorization
@@ -650,44 +647,10 @@ class AuthorizationEndpoint(object):
@property
def response_type_handlers(self):
return {
- u'code': self.authorization_code_handler,
- u'token': self.implicit_handler,
+ u'code': AuthorizationGrantCodeHandler(),
+ u'token': ImplicitGrantHandler(),
}
- def authorization_code_handler(self, params):
- self.grant = self.generate_authorization_grant()
- return add_params_to_uri(self.redirect_uri, self.grant.items())
-
- def generate_authorization_grant(self):
- """Generates an authorization grant represented as a dictionary."""
- grant = {u'code': generate_token()}
- if self.state:
- grant[u'state'] = self.state
- return grant
-
- def save_authorization_grant(self, client_id, code, state=None):
- """Saves authorization codes for later use by the token endpoint.
-
- code: The authorization code generated by the authorization server.
- The authorization code MUST expire shortly after it is issued
- to mitigate the risk of leaks. A maximum authorization code
- lifetime of 10 minutes is RECOMMENDED.
-
- state: A CSRF protection value received from the client.
- """
- raise NotImplementedError('Subclasses must implement this method.')
-
- def generate_implicit_grant(self):
- pass
-
- def save_implicit_grant(self, client_id, creds):
- raise NotImplementedError('Subclasses must implement this method.')
-
- def implicit_handler(self, params):
- self.grant = self.generate_implicit_grant()
- return add_params_to_uri(self.redirect_uri, self.grant.items())
-
- # in django, use decorate like @before_authorization
def parse_authorization_parameters(self, uri):
self.params = params_from_uri(uri)
self.client_id = self.params.get(u'client_id', None)
@@ -733,18 +696,14 @@ class AuthorizationEndpoint(object):
return True
- def validate_authorized_scopes(self, authorized_scopes):
- self.authorized_scopes = authorized_scopes
-
def create_authorization_response(self, authorized_scopes):
self.scopes = authorized_scopes
- try:
- self.validate_authorization_parameters()
- except OAuth2Error as error:
- return error.add_to_uri(self.redirect_uri)
+ if not self.response_type in self.response_type_handlers:
+ raise AuthorizationEndpoint.UnsupportedResponseTypeError(
+ state=self.state, description=u'Invalid response type')
- return self.response_type_handlers.get(self.response_type)(self.params)
+ return self.response_type_handlers.get(self.response_type)(self)
def validate_client(self, client_id):
raise NotImplementedError('Subclasses must implement this method.')
@@ -761,6 +720,21 @@ class AuthorizationEndpoint(object):
def get_default_scopes(self, client_id):
raise NotImplementedError('Subclasses must implement this method.')
+ def save_authorization_grant(self, client_id, grant, state=None):
+ """Saves authorization codes for later use by the token endpoint.
+
+ code: The authorization code generated by the authorization server.
+ The authorization code MUST expire shortly after it is issued
+ to mitigate the risk of leaks. A maximum authorization code
+ lifetime of 10 minutes is RECOMMENDED.
+
+ state: A CSRF protection value received from the client.
+ """
+ raise NotImplementedError('Subclasses must implement this method.')
+
+ def save_implicit_grant(self, client_id, grant, state=None):
+ raise NotImplementedError('Subclasses must implement this method.')
+
def params_from_uri(uri):
import urlparse
@@ -771,6 +745,64 @@ def params_from_uri(uri):
return params
+class AuthorizationGrantCodeHandler(object):
+
+ def __call__(self, endpoint):
+ self.endpoint = endpoint
+ try:
+ self.endpoint.validate_authorization_parameters()
+
+ except OAuth2Error as e:
+ return add_params_to_uri(self.endpoint.redirect_uri, e.twotuples)
+
+ self.grant = self.create_authorization_grant()
+ self.endpoint.save_authorization_grant(
+ self.endpoint.client_id, self.grant, state=self.endpoint.state)
+ return add_params_to_uri(self.endpoint.redirect_uri, self.grant.items())
+
+ def create_authorization_grant(self):
+ """Generates an authorization grant represented as a dictionary."""
+ grant = {u'code': generate_token()}
+ if self.endpoint.state:
+ grant[u'state'] = self.endpoint.state
+ return grant
+
+
+class ImplicitGrantHandler(object):
+
+ @property
+ def expires_in(self):
+ return 3600
+
+ @property
+ def token_type(self):
+ return u'Bearer'
+
+ def create_implicit_grant(self):
+ return {
+ u'access_token': generate_token(),
+ u'token_type': self.token_type,
+ u'expires_in': self.expires_in,
+ u'scope': ' '.join(self.endpoint.scopes),
+ u'state': self.endpoint.state
+ }
+
+ def __call__(self, endpoint):
+ self.endpoint = endpoint
+ try:
+ self.endpoint.validate_authorization_parameters()
+
+ except OAuth2Error as e:
+ return add_params_to_uri(
+ self.endpoint.redirect_uri, e.twotuples, fragment=True)
+
+ self.grant = self.create_implicit_grant()
+ self.endpoint.save_implicit_grant(
+ self.endpoint.client_id, self.grant, state=self.endpoint.state)
+ return add_params_to_uri(
+ self.endpoint.redirect_uri, self.grant.items(), fragment=True)
+
+
class TokenEndpoint(object):
def access_token(self, uri, body, http_method=u'GET', headers=None):
diff --git a/tests/oauth2/draft25/test_server.py b/tests/oauth2/draft25/test_server.py
index 505625a..9c797df 100644
--- a/tests/oauth2/draft25/test_server.py
+++ b/tests/oauth2/draft25/test_server.py
@@ -51,6 +51,12 @@ class AuthorizationEndpointTest(TestCase):
def get_default_redirect_uri(self, client_id):
return u'http://default.redirect/uri'
+ def save_authorization_grant(self, client_id, grant, state=None):
+ pass
+
+ def save_implicit_grant(self, client_id, grant, state=None):
+ pass
+
def test_authorization_parameters(self):
tests = ((self.uri, None, []),
@@ -90,7 +96,6 @@ class AuthorizationEndpointTest(TestCase):
ae = self.SimpleAuthorizationEndpoint(valid_scopes=self.scopes_decoded)
ae.parse_authorization_parameters(uri)
uri = ae.create_authorization_response(self.scopes_decoded)
- self.assertIsNotNone(ae.grant)
self.assertIn(u'state', uri)
self.assertIn(u'code', uri)
for value in extras:
@@ -104,7 +109,6 @@ class AuthorizationEndpointTest(TestCase):
ae = self.SimpleAuthorizationEndpoint(valid_scopes=self.scopes_decoded)
ae.parse_authorization_parameters(uri)
uri = ae.create_authorization_response(self.scopes_decoded)
- self.assertIsNotNone(ae.grant)
self.assertIn(u'access_token', uri)
self.assertIn(u'token_type', uri)
self.assertIn(u'expires_in', uri)
@@ -115,24 +119,25 @@ class AuthorizationEndpointTest(TestCase):
def test_authorization_error_response(self):
tests = ((u'client_id', None, u'invalid_request'),
- (u'response_type', u'invalid', u'unsupported_response_type'),
(u'validate_client', lambda *x: False, u'unauthorized_client'),
(u'validate_scopes', lambda *x: False, u'invalid_scope'),
(u'validate_redirect_uri', lambda *x: False, u'access_denied'))
- for name, attr, result in tests:
- ae = self.SimpleAuthorizationEndpoint(valid_scopes=self.scopes_decoded)
- ae.parse_authorization_parameters(self.uri)
- setattr(ae, name, attr)
- uri = ae.create_authorization_response(self.scopes_decoded)
- self.assertIn(u'error', uri)
- self.assertIn(result, uri)
+ for uri in (self.uri, self.implicit_uri):
+ for name, attr, result in tests:
+ ae = self.SimpleAuthorizationEndpoint(valid_scopes=self.scopes_decoded)
+ ae.parse_authorization_parameters(uri)
+ setattr(ae, name, attr)
+ response_uri = ae.create_authorization_response(self.scopes_decoded)
+ self.assertIn(u'error', response_uri)
+ self.assertIn(result, response_uri)
def test_not_implemented(self):
ae = AuthorizationEndpoint()
- self.assertRaises(NotImplementedError, ae.save_authorization_grant, None, None)
self.assertRaises(NotImplementedError, ae.validate_client, None)
self.assertRaises(NotImplementedError, ae.validate_scopes, None, None)
self.assertRaises(NotImplementedError, ae.validate_redirect_uri, None, None)
self.assertRaises(NotImplementedError, ae.get_default_scopes, None)
self.assertRaises(NotImplementedError, ae.get_default_redirect_uri, None)
+ self.assertRaises(NotImplementedError, ae.save_authorization_grant, None, None)
+ self.assertRaises(NotImplementedError, ae.save_implicit_grant, None, None)