summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikos Sklikas <nsklikas@admin.grnet.gr>2021-03-22 16:38:38 +0200
committerNikos Sklikas <nsklikas@admin.grnet.gr>2021-06-03 18:00:12 +0300
commit595bf5f98ab785aa64840ed469fb1b9dc09bdb9e (patch)
tree11d8e9fe68726582cba3ce6f1d4420e6db5d135b
parent5a1e7483749229e01442cfb969916a14a2078789 (diff)
downloadoauthlib-595bf5f98ab785aa64840ed469fb1b9dc09bdb9e.tar.gz
Add support for refreshing ID Tokens
-rw-r--r--oauthlib/openid/connect/core/grant_types/__init__.py1
-rw-r--r--oauthlib/openid/connect/core/grant_types/refresh_token.py36
-rw-r--r--tests/openid/connect/core/grant_types/test_refresh_token.py99
3 files changed, 136 insertions, 0 deletions
diff --git a/oauthlib/openid/connect/core/grant_types/__init__.py b/oauthlib/openid/connect/core/grant_types/__init__.py
index 887a585..8dad5f6 100644
--- a/oauthlib/openid/connect/core/grant_types/__init__.py
+++ b/oauthlib/openid/connect/core/grant_types/__init__.py
@@ -10,3 +10,4 @@ from .dispatchers import (
)
from .hybrid import HybridGrant
from .implicit import ImplicitGrant
+from .refresh_token import RefreshTokenGrant
diff --git a/oauthlib/openid/connect/core/grant_types/refresh_token.py b/oauthlib/openid/connect/core/grant_types/refresh_token.py
new file mode 100644
index 0000000..386b57c
--- /dev/null
+++ b/oauthlib/openid/connect/core/grant_types/refresh_token.py
@@ -0,0 +1,36 @@
+"""
+oauthlib.openid.connect.core.grant_types
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+"""
+import logging
+
+from oauthlib.oauth2.rfc6749.grant_types.refresh_token import (
+ RefreshTokenGrant as OAuth2RefreshTokenGrant,
+)
+
+from .base import GrantTypeBase
+
+log = logging.getLogger(__name__)
+
+
+class RefreshTokenGrant(GrantTypeBase):
+
+ def __init__(self, refresh_id_token=True, request_validator=None, **kwargs):
+ self.refresh_id_token = refresh_id_token
+ self.proxy_target = OAuth2RefreshTokenGrant(
+ request_validator=request_validator, **kwargs)
+ self.register_token_modifier(self.add_id_token)
+
+ def add_id_token(self, token, token_handler, request):
+ """
+ Construct an initial version of id_token, and let the
+ request_validator sign or encrypt it.
+
+ The authorization_code version of this method is used to
+ retrieve the nonce accordingly to the code storage.
+ """
+ # Treat it as normal OAuth 2 auth code request if openid is not present
+ if not self.refresh_id_token:
+ return token
+
+ return super().add_id_token(token, token_handler, request)
diff --git a/tests/openid/connect/core/grant_types/test_refresh_token.py b/tests/openid/connect/core/grant_types/test_refresh_token.py
new file mode 100644
index 0000000..c19de18
--- /dev/null
+++ b/tests/openid/connect/core/grant_types/test_refresh_token.py
@@ -0,0 +1,99 @@
+import json
+from unittest import mock
+
+from oauthlib.common import Request
+from oauthlib.oauth2.rfc6749.tokens import BearerToken
+from oauthlib.openid.connect.core.grant_types import RefreshTokenGrant
+
+from tests.oauth2.rfc6749.grant_types.test_refresh_token import (
+ RefreshTokenGrantTest,
+)
+from tests.unittest import TestCase
+
+
+def get_id_token_mock(token, token_handler, request):
+ return "MOCKED_TOKEN"
+
+
+class OpenIDRefreshTokenInterferenceTest(RefreshTokenGrantTest):
+ """Test that OpenID don't interfere with normal OAuth 2 flows."""
+
+ def setUp(self):
+ super().setUp()
+ self.auth = RefreshTokenGrant(request_validator=self.mock_validator)
+
+
+class OpenIDRefreshTokenTest(TestCase):
+
+ def setUp(self):
+ self.request = Request('http://a.b/path')
+ self.request.grant_type = 'refresh_token'
+ self.request.refresh_token = 'lsdkfhj230'
+ self.request.scope = ('hello', 'openid')
+ self.mock_validator = mock.MagicMock()
+
+ self.mock_validator = mock.MagicMock()
+ self.mock_validator.authenticate_client.side_effect = self.set_client
+ self.mock_validator.get_id_token.side_effect = get_id_token_mock
+ self.auth = RefreshTokenGrant(request_validator=self.mock_validator)
+
+ def set_client(self, request):
+ request.client = mock.MagicMock()
+ request.client.client_id = 'mocked'
+ return True
+
+ def test_refresh_id_token(self):
+ self.mock_validator.get_original_scopes.return_value = [
+ 'hello', 'openid'
+ ]
+ bearer = BearerToken(self.mock_validator)
+
+ headers, body, status_code = self.auth.create_token_response(
+ self.request, bearer
+ )
+
+ token = json.loads(body)
+ self.assertEqual(self.mock_validator.save_token.call_count, 1)
+ self.assertIn('access_token', token)
+ self.assertIn('refresh_token', token)
+ self.assertIn('id_token', token)
+ self.assertIn('token_type', token)
+ self.assertIn('expires_in', token)
+ self.assertEqual(token['scope'], 'hello openid')
+
+ def test_refresh_id_token_false(self):
+ self.auth.refresh_id_token = False
+ self.mock_validator.get_original_scopes.return_value = [
+ 'hello', 'openid'
+ ]
+ bearer = BearerToken(self.mock_validator)
+
+ headers, body, status_code = self.auth.create_token_response(
+ self.request, bearer
+ )
+
+ token = json.loads(body)
+ self.assertEqual(self.mock_validator.save_token.call_count, 1)
+ self.assertIn('access_token', token)
+ self.assertIn('refresh_token', token)
+ self.assertIn('token_type', token)
+ self.assertIn('expires_in', token)
+ self.assertEqual(token['scope'], 'hello openid')
+ self.assertNotIn('id_token', token)
+
+ def test_refresh_token_without_openid_scope(self):
+ self.request.scope = "hello"
+ bearer = BearerToken(self.mock_validator)
+
+ headers, body, status_code = self.auth.create_token_response(
+ self.request, bearer
+ )
+
+ token = json.loads(body)
+ self.assertEqual(self.mock_validator.save_token.call_count, 1)
+ self.assertIn('access_token', token)
+ self.assertIn('refresh_token', token)
+ self.assertIn('token_type', token)
+ self.assertIn('expires_in', token)
+ self.assertNotIn('id_token', token)
+ self.assertEqual(token['scope'], 'hello')