summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIb Lundgren <ib.lundgren@gmail.com>2013-06-18 21:23:05 +0100
committerIb Lundgren <ib.lundgren@gmail.com>2013-06-18 21:23:05 +0100
commit4d627ce3e0f1ebe346052b8dcae92d04a42af105 (patch)
tree04109160e73ac178cacacf4cbbf8726edd02eef2
parent2261e99ae65fafd03aed337bf100faa6942108e3 (diff)
downloadoauthlib-4d627ce3e0f1ebe346052b8dcae92d04a42af105.tar.gz
Base endpoint for parameter checking and signature verification. #95
-rw-r--r--oauthlib/oauth1/rfc5849/endpoints/base.py213
-rw-r--r--tests/oauth1/rfc5849/endpoints/test_base.py357
2 files changed, 570 insertions, 0 deletions
diff --git a/oauthlib/oauth1/rfc5849/endpoints/base.py b/oauthlib/oauth1/rfc5849/endpoints/base.py
new file mode 100644
index 0000000..4d85c9a
--- /dev/null
+++ b/oauthlib/oauth1/rfc5849/endpoints/base.py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+
+"""
+oauthlib.oauth1.rfc5849
+~~~~~~~~~~~~~~
+
+This module is an implementation of various logic needed
+for signing and checking OAuth 1.0 RFC 5849 requests.
+"""
+
+import time
+
+from oauthlib.common import Request, generate_token
+
+from .. import signature, utils, errors
+from .. import CONTENT_TYPE_FORM_URLENCODED
+from .. import SIGNATURE_HMAC, SIGNATURE_RSA
+from .. import SIGNATURE_TYPE_AUTH_HEADER
+from .. import SIGNATURE_TYPE_QUERY
+from .. import SIGNATURE_TYPE_BODY
+
+
+class BaseEndpoint(object):
+
+ def __init__(self, request_validator, token_generator=None):
+ self.request_validator = request_validator
+ self.token_generator = token_generator or generate_token
+
+ def _get_signature_type_and_params(self, request):
+ """Extracts parameters from query, headers and body. Signature type
+ is set to the source in which parameters were found.
+ """
+ # Per RFC5849, only the Authorization header may contain the 'realm' optional parameter.
+ header_params = signature.collect_parameters(headers=request.headers,
+ exclude_oauth_signature=False, with_realm=True)
+ body_params = signature.collect_parameters(body=request.body,
+ exclude_oauth_signature=False)
+ query_params = signature.collect_parameters(uri_query=request.uri_query,
+ exclude_oauth_signature=False)
+
+ params = []
+ params.extend(header_params)
+ params.extend(body_params)
+ params.extend(query_params)
+ signature_types_with_oauth_params = list(filter(lambda s: s[2], (
+ (SIGNATURE_TYPE_AUTH_HEADER, params,
+ utils.filter_oauth_params(header_params)),
+ (SIGNATURE_TYPE_BODY, params,
+ utils.filter_oauth_params(body_params)),
+ (SIGNATURE_TYPE_QUERY, params,
+ utils.filter_oauth_params(query_params))
+ )))
+
+ if len(signature_types_with_oauth_params) > 1:
+ found_types = [s[0] for s in signature_types_with_oauth_params]
+ raise errors.InvalidRequestError(
+ description=('oauth_ params must come from only 1 signature'
+ 'type but were found in %s',
+ ', '.join(found_types)))
+
+ try:
+ signature_type, params, oauth_params = signature_types_with_oauth_params[0]
+ except IndexError:
+ raise errors.InvalidRequestError(
+ description='Missing mandatory OAuth parameters.')
+
+ return signature_type, params, oauth_params
+
+ def _create_request(self, uri, http_method, body, headers):
+ # Only include body data from x-www-form-urlencoded requests
+ headers = headers or {}
+ if ("Content-Type" in headers and
+ headers["Content-Type"] == CONTENT_TYPE_FORM_URLENCODED):
+ request = Request(uri, http_method, body, headers)
+ else:
+ request = Request(uri, http_method, '', headers)
+
+ signature_type, params, oauth_params = (
+ self._get_signature_type_and_params(request))
+
+ # The server SHOULD return a 400 (Bad Request) status code when
+ # receiving a request with duplicated protocol parameters.
+ if len(dict(oauth_params)) != len(oauth_params):
+ raise errors.InvalidRequestError(
+ description='Duplicate OAuth2 entries.')
+
+ oauth_params = dict(oauth_params)
+ request.signature = oauth_params.get('oauth_signature')
+ request.client_key = oauth_params.get('oauth_consumer_key')
+ request.resource_owner_key = oauth_params.get('oauth_token')
+ request.nonce = oauth_params.get('oauth_nonce')
+ request.timestamp = oauth_params.get('oauth_timestamp')
+ request.redirect_uri = oauth_params.get('oauth_callback')
+ request.verifier = oauth_params.get('oauth_verifier')
+ request.signature_method = oauth_params.get('oauth_signature_method')
+ request.realm = dict(params).get('realm')
+ request.oauth_params = oauth_params
+
+ # Parameters to Client depend on signature method which may vary
+ # for each request. Note that HMAC-SHA1 and PLAINTEXT share parameters
+ request.params = filter(lambda x: x[0] not in ("oauth_signature", "realm"), params)
+
+ return request
+
+ def _check_transport_security(self, request):
+ # TODO: move into oauthlib.common from oauth2.utils
+ if (self.request_validator.enforce_ssl and
+ not request.uri.lower().startswith("https://")):
+ raise errors.InsecureTransportError()
+
+ def _check_mandatory_parameters(self, request):
+ # The server SHOULD return a 400 (Bad Request) status code when
+ # receiving a request with missing parameters.
+ if not all((request.signature, request.client_key,
+ request.nonce, request.timestamp,
+ request.signature_method)):
+ raise errors.InvalidRequestError(
+ description='Missing mandatory OAuth parameters.')
+
+ # OAuth does not mandate a particular signature method, as each
+ # implementation can have its own unique requirements. Servers are
+ # free to implement and document their own custom methods.
+ # Recommending any particular method is beyond the scope of this
+ # specification. Implementers should review the Security
+ # Considerations section (`Section 4`_) before deciding on which
+ # method to support.
+ # .. _`Section 4`: http://tools.ietf.org/html/rfc5849#section-4
+ if (not request.signature_method in
+ self.request_validator.allowed_signature_methods):
+ raise errors.InvalidSignatureMethodError(
+ description="Invalid signature, %s not in %r." % (
+ request.signature_method,
+ self.request_validator.allowed_signature_methods))
+
+ # Servers receiving an authenticated request MUST validate it by:
+ # If the "oauth_version" parameter is present, ensuring its value is
+ # "1.0".
+ if ('oauth_version' in request.oauth_params and
+ request.oauth_params['oauth_version'] != '1.0'):
+ raise errors.InvalidRequestError(
+ description='Invalid OAuth version.')
+
+ # The timestamp value MUST be a positive integer. Unless otherwise
+ # specified by the server's documentation, the timestamp is expressed
+ # in the number of seconds since January 1, 1970 00:00:00 GMT.
+ if len(request.timestamp) != 10:
+ raise errors.InvalidRequestError(
+ description='Invalid timestamp size')
+
+ try:
+ ts = int(request.timestamp)
+
+ except ValueError:
+ raise errors.InvalidRequestError(
+ description='Timestamp must be an integer.')
+
+ else:
+ # To avoid the need to retain an infinite number of nonce values for
+ # future checks, servers MAY choose to restrict the time period after
+ # which a request with an old timestamp is rejected.
+ print(self.request_validator.timestamp_lifetime)
+ print(float(self.request_validator.timestamp_lifetime))
+ if abs(time.time() - ts) > self.request_validator.timestamp_lifetime:
+ raise errors.InvalidRequestError(
+ description=('Timestamp given is invalid, differ from '
+ 'allowed by over %s seconds.' % (
+ self.request_validator.timestamp_lifetime)))
+
+ # Provider specific validation of parameters, used to enforce
+ # restrictions such as character set and length.
+ if not self.request_validator.check_client_key(request.client_key):
+ raise errors.InvalidRequestError(
+ description='Invalid client key format.')
+
+ if not self.request_validator.check_nonce(request.nonce):
+ raise errors.InvalidRequestError(
+ description='Invalid nonce format.')
+
+ def _check_signature(self, request, is_token_request=False):
+ # ---- RSA Signature verification ----
+ if request.signature_method == SIGNATURE_RSA:
+ # The server verifies the signature per `[RFC3447] section 8.2.2`_
+ # .. _`[RFC3447] section 8.2.2`: http://tools.ietf.org/html/rfc3447#section-8.2.1
+ rsa_key = self.request_validator.get_rsa_key(
+ request.client_key, request)
+ valid_signature = signature.verify_rsa_sha1(request, rsa_key)
+
+ # ---- HMAC or Plaintext Signature verification ----
+ else:
+ # Servers receiving an authenticated request MUST validate it by:
+ # Recalculating the request signature independently as described in
+ # `Section 3.4`_ and comparing it to the value received from the
+ # client via the "oauth_signature" parameter.
+ # .. _`Section 3.4`: http://tools.ietf.org/html/rfc5849#section-3.4
+ client_secret = self.request_validator.get_client_secret(
+ request.client_key, request)
+ resource_owner_secret = None
+ if request.resource_owner_key:
+ if is_token_request:
+ resource_owner_secret = self.request_validator.get_request_token_secret(
+ request.client_key, request.resource_owner_key, request)
+ else:
+ resource_owner_secret = self.request_validator.get_access_token_secret(
+ request.client_key, request.resource_owner_key, request)
+
+ if request.signature_method == SIGNATURE_HMAC:
+ valid_signature = signature.verify_hmac_sha1(request,
+ client_secret, resource_owner_secret)
+ else:
+ valid_signature = signature.verify_plaintext(request,
+ client_secret, resource_owner_secret)
+ return valid_signature
diff --git a/tests/oauth1/rfc5849/endpoints/test_base.py b/tests/oauth1/rfc5849/endpoints/test_base.py
new file mode 100644
index 0000000..f175e6d
--- /dev/null
+++ b/tests/oauth1/rfc5849/endpoints/test_base.py
@@ -0,0 +1,357 @@
+from __future__ import unicode_literals, absolute_import
+
+from mock import MagicMock
+from re import sub
+from ....unittest import TestCase
+
+from oauthlib.common import safe_string_equals
+from oauthlib.oauth1 import Client, RequestValidator
+from oauthlib.oauth1.rfc5849 import errors, SIGNATURE_RSA, SIGNATURE_HMAC
+from oauthlib.oauth1.rfc5849 import SIGNATURE_PLAINTEXT
+from oauthlib.oauth1.rfc5849.endpoints import RequestTokenEndpoint, BaseEndpoint
+
+
+URLENCODED = {"Content-Type": "application/x-www-form-urlencoded"}
+
+
+class BaseEndpointTest(TestCase):
+
+ def setUp(self):
+ self.validator = MagicMock(spec=RequestValidator)
+ self.validator.allowed_signature_methods = ['HMAC-SHA1']
+ self.validator.timestamp_lifetime = 600
+ self.endpoint = RequestTokenEndpoint(self.validator)
+ self.client = Client('foo', callback_uri='https://c.b/cb')
+ self.uri, self.headers, self.body = self.client.sign(
+ 'https://i.b/request_token')
+
+ def test_ssl_enforcement(self):
+ uri, headers, _ = self.client.sign('http://i.b/request_token')
+ u, h, b, s = self.endpoint.create_request_token_response(
+ uri, headers=headers)
+ self.assertEqual(s, 400)
+ self.assertIn('insecure_transport_protocol', b)
+
+ def test_missing_parameters(self):
+ u, h, b, s = self.endpoint.create_request_token_response(self.uri)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_request', b)
+
+ def test_signature_methods(self):
+ headers = {}
+ headers['Authorization'] = self.headers['Authorization'].replace(
+ 'HMAC', 'RSA')
+ u, h, b, s = self.endpoint.create_request_token_response(
+ self.uri, headers=headers)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_signature_method', b)
+
+ def test_invalid_version(self):
+ headers = {}
+ headers['Authorization'] = self.headers['Authorization'].replace(
+ '1.0', '2.0')
+ u, h, b, s = self.endpoint.create_request_token_response(
+ self.uri, headers=headers)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_request', b)
+
+ def test_expired_timestamp(self):
+ headers = {}
+ for pattern in ('12345678901', '4567890123', '123456789K'):
+ headers['Authorization'] = sub('timestamp="\d*k?"',
+ 'timestamp="%s"' % pattern,
+ self.headers['Authorization'])
+ u, h, b, s = self.endpoint.create_request_token_response(
+ self.uri, headers=headers)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_request', b)
+
+ def test_client_key_check(self):
+ self.validator.check_client_key.return_value = False
+ u, h, b, s = self.endpoint.create_request_token_response(
+ self.uri, headers=self.headers)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_request', b)
+
+ def test_noncecheck(self):
+ self.validator.check_nonce.return_value = False
+ u, h, b, s = self.endpoint.create_request_token_response(
+ self.uri, headers=self.headers)
+ self.assertEqual(s, 400)
+ self.assertIn('invalid_request', b)
+
+ def test_enforce_ssl(self):
+ """Ensure SSL is enforced by default."""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ c = Client('foo')
+ u, h, b = c.sign('http://example.com')
+ r = e._create_request(u, 'GET', b, h)
+ self.assertRaises(errors.InsecureTransportError,
+ e._check_transport_security, r)
+
+ def test_multiple_source_params(self):
+ """Check for duplicate params"""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ self.assertRaises(errors.InvalidRequestError, e._create_request,
+ 'https://a.b/?oauth_signature_method=HMAC-SHA1',
+ 'GET', 'oauth_version=foo', URLENCODED)
+ headers = {'Authorization': 'OAuth oauth_signature="foo"'}
+ headers.update(URLENCODED)
+ self.assertRaises(errors.InvalidRequestError, e._create_request,
+ 'https://a.b/?oauth_signature_method=HMAC-SHA1',
+ 'GET',
+ 'oauth_version=foo',
+ headers)
+ headers = {'Authorization': 'OAuth oauth_signature_method="foo"'}
+ headers.update(URLENCODED)
+ self.assertRaises(errors.InvalidRequestError, e._create_request,
+ 'https://a.b/',
+ 'GET',
+ 'oauth_signature=foo',
+ headers)
+
+ def test_duplicate_params(self):
+ """Ensure params are only supplied once"""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ self.assertRaises(errors.InvalidRequestError, e._create_request,
+ 'https://a.b/?oauth_version=a&oauth_version=b',
+ 'GET', None, URLENCODED)
+ self.assertRaises(errors.InvalidRequestError, e._create_request,
+ 'https://a.b/', 'GET', 'oauth_version=a&oauth_version=b',
+ URLENCODED)
+
+ def test_mandated_params(self):
+ """Ensure all mandatory params are present."""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request('https://a.b/', 'GET',
+ 'oauth_signature=a&oauth_consumer_key=b&oauth_nonce',
+ URLENCODED)
+ self.assertRaises(errors.InvalidRequestError,
+ e._check_mandatory_parameters, r)
+
+ def test_oauth_version(self):
+ """OAuth version must be 1.0 if present."""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request('https://a.b/', 'GET',
+ ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_timestamp=a&oauth_signature_method=RSA-SHA1&'
+ 'oauth_version=2.0'),
+ URLENCODED)
+ self.assertRaises(errors.InvalidRequestError,
+ e._check_mandatory_parameters, r)
+
+ def test_oauth_timestamp(self):
+ """Check for a valid UNIX timestamp."""
+ v = RequestValidator()
+ e = BaseEndpoint(v)
+
+ # Invalid timestamp length, must be 10
+ r = e._create_request('https://a.b/', 'GET',
+ ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
+ 'oauth_timestamp=123456789'),
+ URLENCODED)
+ self.assertRaises(errors.InvalidRequestError,
+ e._check_mandatory_parameters, r)
+
+ # Invalid timestamp age, must be younger than 10 minutes
+ r = e._create_request('https://a.b/', 'GET',
+ ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
+ 'oauth_timestamp=1234567890'),
+ URLENCODED)
+ self.assertRaises(errors.InvalidRequestError,
+ e._check_mandatory_parameters, r)
+
+ # Timestamp must be an integer
+ r = e._create_request('https://a.b/', 'GET',
+ ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_version=1.0&oauth_signature_method=RSA-SHA1&'
+ 'oauth_timestamp=123456789a'),
+ URLENCODED)
+ self.assertRaises(errors.InvalidRequestError,
+ e._check_mandatory_parameters, r)
+
+ def test_signature_method_validation(self):
+ """Ensure valid signature method is used."""
+
+ body = ('oauth_signature=a&oauth_consumer_key=b&oauth_nonce=c&'
+ 'oauth_version=1.0&oauth_signature_method=%s&'
+ 'oauth_timestamp=1234567890')
+
+ uri = 'https://example.com/'
+
+ class HMACValidator(RequestValidator):
+
+ @property
+ def allowed_signature_methods(self):
+ return (SIGNATURE_HMAC,)
+
+ v = HMACValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+
+ class RSAValidator(RequestValidator):
+
+ @property
+ def allowed_signature_methods(self):
+ return (SIGNATURE_RSA,)
+
+ v = RSAValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'PLAINTEXT', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+
+ class PlainValidator(RequestValidator):
+
+ @property
+ def allowed_signature_methods(self):
+ return (SIGNATURE_PLAINTEXT,)
+
+ v = PlainValidator()
+ e = BaseEndpoint(v)
+ r = e._create_request(uri, 'GET', body % 'HMAC-SHA1', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'RSA-SHA1', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+ r = e._create_request(uri, 'GET', body % 'shibboleth', URLENCODED)
+ self.assertRaises(errors.InvalidSignatureMethodError,
+ e._check_mandatory_parameters, r)
+
+
+class ClientValidator(RequestValidator):
+ clients = ['foo']
+ nonces = [('foo', 'once', '1234567891', 'fez')]
+ owners = {'foo': ['abcdefghijklmnopqrstuvxyz', 'fez']}
+ assigned_realms = {('foo', 'abcdefghijklmnopqrstuvxyz'): 'photos'}
+ verifiers = {('foo', 'fez'): 'shibboleth'}
+
+ @property
+ def client_key_length(self):
+ return 1, 30
+
+ @property
+ def request_token_length(self):
+ return 1, 30
+
+ @property
+ def access_token_length(self):
+ return 1, 30
+
+ @property
+ def nonce_length(self):
+ return 2, 30
+
+ @property
+ def verifier_length(self):
+ return 2, 30
+
+ @property
+ def realms(self):
+ return ['photos']
+
+ @property
+ def timestamp_lifetime(self):
+ # Disabled check to allow hardcoded verification signatures
+ return 1000000000
+
+ @property
+ def dummy_client(self):
+ return 'dummy'
+
+ @property
+ def dummy_request_token(self):
+ return 'dumbo'
+
+ @property
+ def dummy_access_token(self):
+ return 'dumbo'
+
+ def validate_timestamp_and_nonce(self, client_key, timestamp, nonce,
+ request, request_token=None, access_token=None):
+ resource_owner_key = request_token if request_token else access_token
+ return not (client_key, nonce, timestamp, resource_owner_key) in self.nonces
+
+ def validate_client_key(self, client_key):
+ return client_key in self.clients
+
+ def validate_access_token(self, client_key, access_token, request):
+ return (self.owners.get(client_key) and
+ access_token in self.owners.get(client_key))
+
+ def validate_request_token(self, client_key, request_token, request):
+ return (self.owners.get(client_key) and
+ request_token in self.owners.get(client_key))
+
+ def validate_requested_realm(self, client_key, realm, request):
+ return True
+
+ def validate_realm(self, client_key, access_token, request, uri=None,
+ required_realm=None):
+ return (client_key, access_token) in self.assigned_realms
+
+ def validate_verifier(self, client_key, request_token, verifier,
+ request):
+ return ((client_key, request_token) in self.verifiers and
+ safe_string_equals(verifier, self.verifiers.get(
+ (client_key, request_token))))
+
+ def validate_redirect_uri(self, client_key, redirect_uri, request):
+ return redirect_uri.startswith('http://client.example.com/')
+
+ def get_client_secret(self, client_key, request):
+ return 'super secret'
+
+ def get_access_token_secret(self, client_key, access_token, request):
+ return 'even more secret'
+
+ def get_request_token_secret(self, client_key, request_token, request):
+ return 'even more secret'
+
+
+class SignatureVerificationTest(TestCase):
+
+ def test_signature_verification(self):
+ v = ClientValidator()
+ e = BaseEndpoint(v)
+
+ uri = 'https://example.com/'
+ short_sig = ('oauth_signature=fmrXnTF4lO4o%2BD0%2FlZaJHP%2FXqEY&'
+ 'oauth_timestamp=1234567890&'
+ 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
+ 'oauth_version=1.0&oauth_signature_method=HMAC-SHA1&'
+ 'oauth_token=abcdefghijklmnopqrstuvxyz&'
+ 'oauth_consumer_key=foo')
+ r = e._create_request(uri, 'GET', short_sig, URLENCODED)
+ self.assertFalse(e._check_signature(r))
+
+ plain = ('oauth_signature=correctlengthbutthewrongcontent1111&'
+ 'oauth_timestamp=1234567890&'
+ 'oauth_nonce=abcdefghijklmnopqrstuvwxyz&'
+ 'oauth_version=1.0&oauth_signature_method=PLAINTEXT&'
+ 'oauth_token=abcdefghijklmnopqrstuvxyz&'
+ 'oauth_consumer_key=foo')
+ r = e._create_request(uri, 'GET', plain, URLENCODED)
+ self.assertFalse(e._check_signature(r))