diff options
author | Ib Lundgren <ib.lundgren@gmail.com> | 2013-06-18 21:23:05 +0100 |
---|---|---|
committer | Ib Lundgren <ib.lundgren@gmail.com> | 2013-06-18 21:23:05 +0100 |
commit | 4d627ce3e0f1ebe346052b8dcae92d04a42af105 (patch) | |
tree | 04109160e73ac178cacacf4cbbf8726edd02eef2 | |
parent | 2261e99ae65fafd03aed337bf100faa6942108e3 (diff) | |
download | oauthlib-4d627ce3e0f1ebe346052b8dcae92d04a42af105.tar.gz |
Base endpoint for parameter checking and signature verification. #95
-rw-r--r-- | oauthlib/oauth1/rfc5849/endpoints/base.py | 213 | ||||
-rw-r--r-- | tests/oauth1/rfc5849/endpoints/test_base.py | 357 |
2 files changed, 570 insertions, 0 deletions
diff --git a/oauthlib/oauth1/rfc5849/endpoints/base.py b/oauthlib/oauth1/rfc5849/endpoints/base.py new file mode 100644 index 0000000..4d85c9a --- /dev/null +++ b/oauthlib/oauth1/rfc5849/endpoints/base.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, unicode_literals + +""" +oauthlib.oauth1.rfc5849 +~~~~~~~~~~~~~~ + +This module is an implementation of various logic needed +for signing and checking OAuth 1.0 RFC 5849 requests. +""" + +import time + +from oauthlib.common import Request, generate_token + +from .. import signature, utils, errors +from .. import CONTENT_TYPE_FORM_URLENCODED +from .. import SIGNATURE_HMAC, SIGNATURE_RSA +from .. import SIGNATURE_TYPE_AUTH_HEADER +from .. import SIGNATURE_TYPE_QUERY +from .. import SIGNATURE_TYPE_BODY + + +class BaseEndpoint(object): + + def __init__(self, request_validator, token_generator=None): + self.request_validator = request_validator + self.token_generator = token_generator or generate_token + + def _get_signature_type_and_params(self, request): + """Extracts parameters from query, headers and body. Signature type + is set to the source in which parameters were found. + """ + # Per RFC5849, only the Authorization header may contain the 'realm' optional parameter. + header_params = signature.collect_parameters(headers=request.headers, + exclude_oauth_signature=False, with_realm=True) + body_params = signature.collect_parameters(body=request.body, + exclude_oauth_signature=False) + query_params = signature.collect_parameters(uri_query=request.uri_query, + exclude_oauth_signature=False) + + params = [] + params.extend(header_params) + params.extend(body_params) + params.extend(query_params) + signature_types_with_oauth_params = list(filter(lambda s: s[2], ( + (SIGNATURE_TYPE_AUTH_HEADER, params, + utils.filter_oauth_params(header_params)), + (SIGNATURE_TYPE_BODY, params, + utils.filter_oauth_params(body_params)), + (SIGNATURE_TYPE_QUERY, params, + utils.filter_oauth_params(query_params)) + ))) + + if len(signature_types_with_oauth_params) > 1: + found_types = [s[0] for s in signature_types_with_oauth_params] + raise errors.InvalidRequestError( + description=('oauth_ params must come from only 1 signature' + 'type but were found in %s', + ', '.join(found_types))) + + try: + signature_type, params, oauth_params = signature_types_with_oauth_params[0] + except IndexError: + raise errors.InvalidRequestError( + description='Missing mandatory OAuth parameters.') + + return signature_type, params, oauth_params + + def _create_request(self, uri, http_method, body, headers): + # Only include body data from x-www-form-urlencoded requests + headers = headers or {} + if ("Content-Type" in headers and + headers["Content-Type"] == CONTENT_TYPE_FORM_URLENCODED): + request = Request(uri, http_method, body, headers) + else: + request = Request(uri, http_method, '', headers) + + signature_type, params, oauth_params = ( + self._get_signature_type_and_params(request)) + + # The server SHOULD return a 400 (Bad Request) status code when + # receiving a request with duplicated protocol parameters. + if len(dict(oauth_params)) != len(oauth_params): + raise errors.InvalidRequestError( + description='Duplicate OAuth2 entries.') + + oauth_params = dict(oauth_params) + request.signature = oauth_params.get('oauth_signature') + request.client_key = oauth_params.get('oauth_consumer_key') + request.resource_owner_key = oauth_params.get('oauth_token') + request.nonce = oauth_params.get('oauth_nonce') + request.timestamp = oauth_params.get('oauth_timestamp') + request.redirect_uri = oauth_params.get('oauth_callback') + request.verifier = oauth_params.get('oauth_verifier') + request.signature_method = oauth_params.get('oauth_signature_method') + request.realm = dict(params).get('realm') + request.oauth_params = oauth_params + + # Parameters to Client depend on signature method which may vary + # for each request. Note that HMAC-SHA1 and PLAINTEXT share parameters + request.params = filter(lambda x: x[0] not in ("oauth_signature", "realm"), params) + + return request + + def _check_transport_security(self, request): + # TODO: move into oauthlib.common from oauth2.utils + if (self.request_validator.enforce_ssl and + not request.uri.lower().startswith("https://")): + raise errors.InsecureTransportError() + + def _check_mandatory_parameters(self, request): + # The server SHOULD return a 400 (Bad Request) status code when + # receiving a request with missing parameters. + if not all((request.signature, request.client_key, + request.nonce, request.timestamp, + request.signature_method)): + raise errors.InvalidRequestError( + description='Missing mandatory OAuth parameters.') + + # OAuth does not mandate a particular signature method, as each + # implementation can have its own unique requirements. Servers are + # free to implement and document their own custom methods. + # Recommending any particular method is beyond the scope of this + # specification. Implementers should review the Security + # Considerations section (`Section 4`_) before deciding on which + # method to support. + # .. _`Section 4`: http://tools.ietf.org/html/rfc5849#section-4 + if (not request.signature_method in + self.request_validator.allowed_signature_methods): + raise errors.InvalidSignatureMethodError( + description="Invalid signature, %s not in %r." % ( + request.signature_method, + self.request_validator.allowed_signature_methods)) + + # Servers receiving an authenticated request MUST validate it by: + # If the "oauth_version" parameter is present, ensuring its value is + # "1.0". + if ('oauth_version' in request.oauth_params and + request.oauth_params['oauth_version'] != '1.0'): + raise errors.InvalidRequestError( + description='Invalid OAuth version.') + + # The timestamp value MUST be a positive integer. Unless otherwise + # specified by the server's documentation, the timestamp is expressed + # in the number of seconds since January 1, 1970 00:00:00 GMT. + if len(request.timestamp) != 10: + raise errors.InvalidRequestError( + description='Invalid timestamp size') + + try: + ts = int(request.timestamp) + + except ValueError: + raise errors.InvalidRequestError( + description='Timestamp must be an integer.') + + else: + # To avoid the need to retain an infinite number of nonce values for + # future checks, servers MAY choose to restrict the time period after + # which a request with an old timestamp is rejected. + print(self.request_validator.timestamp_lifetime) + print(float(self.request_validator.timestamp_lifetime)) + if abs(time.time() - ts) > self.request_validator.timestamp_lifetime: + raise errors.InvalidRequestError( + description=('Timestamp given is invalid, differ from ' + 'allowed by over %s seconds.' % ( + self.request_validator.timestamp_lifetime))) + + # Provider specific validation of parameters, used to enforce + # restrictions such as character set and length. + if not self.request_validator.check_client_key(request.client_key): + raise errors.InvalidRequestError( + description='Invalid client key format.') + + if not self.request_validator.check_nonce(request.nonce): + raise errors.InvalidRequestError( + description='Invalid nonce format.') + + def _check_signature(self, request, is_token_request=False): + # ---- RSA Signature verification ---- + if request.signature_method == SIGNATURE_RSA: + # The server verifies the signature per `[RFC3447] section 8.2.2`_ + # .. _`[RFC3447] section 8.2.2`: http://tools.ietf.org/html/rfc3447#section-8.2.1 + rsa_key = self.request_validator.get_rsa_key( + request.client_key, request) + valid_signature = signature.verify_rsa_sha1(request, rsa_key) + + # ---- HMAC or Plaintext Signature verification ---- + else: + # Servers receiving an authenticated request MUST validate it by: + # Recalculating the request signature independently as described in + # `Section 3.4`_ and comparing it to the value received from the + # client via the "oauth_signature" parameter. + # .. _`Section 3.4`: http://tools.ietf.org/html/rfc5849#section-3.4 + client_secret = self.request_validator.get_client_secret( + request.client_key, request) + resource_owner_secret = None + if request.resource_owner_key: + if is_token_request: + resource_owner_secret = self.request_validator.get_request_token_secret( + request.client_key, request.resource_owner_key, request) + else: + resource_owner_secret = self.request_validator.get_access_token_secret( + request.client_key, request.resource_owner_key, request) + + if request.signature_method == SIGNATURE_HMAC: + valid_signature = signature.verify_hmac_sha1(request, + client_secret, resource_owner_secret) + else: + valid_signature = signature.verify_plaintext(request, + client_secret, resource_owner_secret) + return valid_signature diff --git a/tests/oauth1/rfc5849/endpoints/test_base.py b/tests/oauth1/rfc5849/endpoints/test_base.py new file mode 100644 index 0000000..f175e6d --- /dev/null +++ b/tests/oauth1/rfc5849/endpoints/test_base.py @@ -0,0 +1,357 @@ +from __future__ import unicode_literals, absolute_import + +from mock import MagicMock +from re import sub +from ....unittest import TestCase + +from oauthlib.common import safe_string_equals +from oauthlib.oauth1 import Client, RequestValidator +from oauthlib.oauth1.rfc5849 import errors, SIGNATURE_RSA, SIGNATURE_HMAC +from oauthlib.oauth1.rfc5849 import SIGNATURE_PLAINTEXT +from oauthlib.oauth1.rfc5849.endpoints import RequestTokenEndpoint, BaseEndpoint + + +URLENCODED = {"Content-Type": "application/x-www-form-urlencoded"} + + +class BaseEndpointTest(TestCase): + + def setUp(self): + self.validator = MagicMock(spec=RequestValidator) + self.validator.allowed_signature_methods = ['HMAC-SHA1'] + self.validator.timestamp_lifetime = 600 + self.endpoint = RequestTokenEndpoint(self.validator) + self.client = Client('foo', callback_uri='https://c.b/cb') + self.uri, self.headers, self.body = self.client.sign( + 'https://i.b/request_token') + + def test_ssl_enforcement(self): + uri, headers, _ = self.client.sign('http://i.b/request_token') + u, h, b, s = self.endpoint.create_request_token_response( + uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('insecure_transport_protocol', b) + + def test_missing_parameters(self): + u, h, b, s = self.endpoint.create_request_token_response(self.uri) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_signature_methods(self): + headers = {} + headers['Authorization'] = self.headers['Authorization'].replace( + 'HMAC', 'RSA') + u, h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_signature_method', b) + + def test_invalid_version(self): + headers = {} + headers['Authorization'] = self.headers['Authorization'].replace( + '1.0', '2.0') + u, h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_expired_timestamp(self): + headers = {} + for pattern in ('12345678901', '4567890123', '123456789K'): + headers['Authorization'] = sub('timestamp="\d*k?"', + 'timestamp="%s"' % pattern, + self.headers['Authorization']) + u, h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_client_key_check(self): + self.validator.check_client_key.return_value = False + u, h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_noncecheck(self): + self.validator.check_nonce.return_value = False + u, h, b, s = self.endpoint.create_request_token_response( + self.uri, headers=self.headers) + self.assertEqual(s, 400) + self.assertIn('invalid_request', b) + + def test_enforce_ssl(self): + """Ensure SSL is enforced by default.""" + v = RequestValidator() + e = BaseEndpoint(v) + c = Client('foo') + u, h, b = c.sign('http://example.com') + r = e._create_request(u, 'GET', b, h) + self.assertRaises(errors.InsecureTransportError, + e._check_transport_security, r) + + def test_multiple_source_params(self): + """Check for duplicate params""" + v = RequestValidator() + e = BaseEndpoint(v) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_signature_method=HMAC-SHA1', + 'GET', 'oauth_version=foo', URLENCODED) + headers = {'Authorization': 'OAuth oauth_signature="foo"'} + headers.update(URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_signature_method=HMAC-SHA1', + 'GET', + 'oauth_version=foo', + headers) + headers = {'Authorization': 'OAuth oauth_signature_method="foo"'} + headers.update(URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/', + 'GET', + 'oauth_signature=foo', + headers) + + def test_duplicate_params(self): + """Ensure params are only supplied once""" + v = RequestValidator() + e = BaseEndpoint(v) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/?oauth_version=a&oauth_version=b', + 'GET', None, URLENCODED) + self.assertRaises(errors.InvalidRequestError, e._create_request, + 'https://a.b/', 'GET', 'oauth_version=a&oauth_version=b', + URLENCODED) + + def test_mandated_params(self): + """Ensure all mandatory params are present.""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b/', 'GET', + 'oauth_signature=a&oauth_consumer_key=b&oauth_nonce', + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_oauth_version(self): + """OAuth version must be 1.0 if present.""" + v = RequestValidator() + e = BaseEndpoint(v) + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&' + 'oauth_version=2.0'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_oauth_timestamp(self): + """Check for a valid UNIX timestamp.""" + v = RequestValidator() + e = BaseEndpoint(v) + + # Invalid timestamp length, must be 10 + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + # Invalid timestamp age, must be younger than 10 minutes + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=1234567890'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + # Timestamp must be an integer + r = e._create_request('https://a.b/', 'GET', + ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&' + 'oauth_timestamp=123456789a'), + URLENCODED) + self.assertRaises(errors.InvalidRequestError, + e._check_mandatory_parameters, r) + + def test_signature_method_validation(self): + """Ensure valid signature method is used.""" + + body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&' + 'oauth_version=1.0&oauth_signature_method=%s&' + 'oauth_timestamp=1234567890') + + uri = 'https://example.com/' + + class HMACValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_HMAC,) + + v = HMACValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + class RSAValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_RSA,) + + v = RSAValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + class PlainValidator(RequestValidator): + + @property + def allowed_signature_methods(self): + return (SIGNATURE_PLAINTEXT,) + + v = PlainValidator() + e = BaseEndpoint(v) + r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED) + self.assertRaises(errors.InvalidSignatureMethodError, + e._check_mandatory_parameters, r) + + +class ClientValidator(RequestValidator): + clients = ['foo'] + nonces = [('foo', 'once', '1234567891', 'fez')] + owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']} + assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'} + verifiers = {('foo', 'fez'): 'shibboleth'} + + @property + def client_key_length(self): + return 1, 30 + + @property + def request_token_length(self): + return 1, 30 + + @property + def access_token_length(self): + return 1, 30 + + @property + def nonce_length(self): + return 2, 30 + + @property + def verifier_length(self): + return 2, 30 + + @property + def realms(self): + return ['photos'] + + @property + def timestamp_lifetime(self): + # Disabled check to allow hardcoded verification signatures + return 1000000000 + + @property + def dummy_client(self): + return 'dummy' + + @property + def dummy_request_token(self): + return 'dumbo' + + @property + def dummy_access_token(self): + return 'dumbo' + + def validate_timestamp_and_nonce(self, client_key, timestamp, nonce, + request, request_token=None, access_token=None): + resource_owner_key = request_token if request_token else access_token + return not (client_key, nonce, timestamp, resource_owner_key) in self.nonces + + def validate_client_key(self, client_key): + return client_key in self.clients + + def validate_access_token(self, client_key, access_token, request): + return (self.owners.get(client_key) and + access_token in self.owners.get(client_key)) + + def validate_request_token(self, client_key, request_token, request): + return (self.owners.get(client_key) and + request_token in self.owners.get(client_key)) + + def validate_requested_realm(self, client_key, realm, request): + return True + + def validate_realm(self, client_key, access_token, request, uri=None, + required_realm=None): + return (client_key, access_token) in self.assigned_realms + + def validate_verifier(self, client_key, request_token, verifier, + request): + return ((client_key, request_token) in self.verifiers and + safe_string_equals(verifier, self.verifiers.get( + (client_key, request_token)))) + + def validate_redirect_uri(self, client_key, redirect_uri, request): + return redirect_uri.startswith('http://client.example.com/') + + def get_client_secret(self, client_key, request): + return 'super secret' + + def get_access_token_secret(self, client_key, access_token, request): + return 'even more secret' + + def get_request_token_secret(self, client_key, request_token, request): + return 'even more secret' + + +class SignatureVerificationTest(TestCase): + + def test_signature_verification(self): + v = ClientValidator() + e = BaseEndpoint(v) + + uri = 'https://example.com/' + short_sig = ('oauth_signature=fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY&' + 'oauth_timestamp=1234567890&' + 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&' + 'oauth_version=1.0&oauth_signature_method=HMAC-SHA1&' + 'oauth_token=abcdefghijklmnopqrstuvxyz&' + 'oauth_consumer_key=foo') + r = e._create_request(uri, 'GET', short_sig, URLENCODED) + self.assertFalse(e._check_signature(r)) + + plain = ('oauth_signature=correctlengthbutthewrongcontent1111&' + 'oauth_timestamp=1234567890&' + 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&' + 'oauth_version=1.0&oauth_signature_method=PLAINTEXT&' + 'oauth_token=abcdefghijklmnopqrstuvxyz&' + 'oauth_consumer_key=foo') + r = e._create_request(uri, 'GET', plain, URLENCODED) + self.assertFalse(e._check_signature(r)) |