summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikos Sklikas <nsklikas@admin.grnet.gr>2021-06-02 11:12:32 +0300
committerNikos Sklikas <nsklikas@admin.grnet.gr>2021-06-03 18:00:42 +0300
commitf6b625886d03f1582a7a99317e84c57d03895339 (patch)
treef9b13648a3e933fa17756fe66d2cad054412bc17
parentcebec2b075600e88c3fdcf554125ecf086e1b500 (diff)
downloadoauthlib-f6b625886d03f1582a7a99317e84c57d03895339.tar.gz
Move refresh_id_token to validator function
-rw-r--r--oauthlib/openid/connect/core/grant_types/refresh_token.py6
-rw-r--r--oauthlib/openid/connect/core/request_validator.py12
-rw-r--r--tests/openid/connect/core/grant_types/test_refresh_token.py8
3 files changed, 21 insertions, 5 deletions
diff --git a/oauthlib/openid/connect/core/grant_types/refresh_token.py b/oauthlib/openid/connect/core/grant_types/refresh_token.py
index 386b57c..43e4499 100644
--- a/oauthlib/openid/connect/core/grant_types/refresh_token.py
+++ b/oauthlib/openid/connect/core/grant_types/refresh_token.py
@@ -15,8 +15,7 @@ log = logging.getLogger(__name__)
class RefreshTokenGrant(GrantTypeBase):
- def __init__(self, refresh_id_token=True, request_validator=None, **kwargs):
- self.refresh_id_token = refresh_id_token
+ def __init__(self, request_validator=None, **kwargs):
self.proxy_target = OAuth2RefreshTokenGrant(
request_validator=request_validator, **kwargs)
self.register_token_modifier(self.add_id_token)
@@ -29,8 +28,7 @@ class RefreshTokenGrant(GrantTypeBase):
The authorization_code version of this method is used to
retrieve the nonce accordingly to the code storage.
"""
- # Treat it as normal OAuth 2 auth code request if openid is not present
- if not self.refresh_id_token:
+ if not self.request_validator.refresh_id_token(request):
return token
return super().add_id_token(token, token_handler, request)
diff --git a/oauthlib/openid/connect/core/request_validator.py b/oauthlib/openid/connect/core/request_validator.py
index e8f334b..47c4cd9 100644
--- a/oauthlib/openid/connect/core/request_validator.py
+++ b/oauthlib/openid/connect/core/request_validator.py
@@ -306,3 +306,15 @@ class RequestValidator(OAuth2RequestValidator):
Method is used by:
UserInfoEndpoint
"""
+
+ def refresh_id_token(self, request):
+ """Whether the id token should be refreshed. Default, True
+
+ :param request: OAuthlib request.
+ :type request: oauthlib.common.Request
+ :rtype: True or False
+
+ Method is used by:
+ RefreshTokenGrant
+ """
+ return True
diff --git a/tests/openid/connect/core/grant_types/test_refresh_token.py b/tests/openid/connect/core/grant_types/test_refresh_token.py
index c19de18..8126e1b 100644
--- a/tests/openid/connect/core/grant_types/test_refresh_token.py
+++ b/tests/openid/connect/core/grant_types/test_refresh_token.py
@@ -60,9 +60,12 @@ class OpenIDRefreshTokenTest(TestCase):
self.assertIn('token_type', token)
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
+ self.mock_validator.refresh_id_token.assert_called_once_with(
+ self.request
+ )
def test_refresh_id_token_false(self):
- self.auth.refresh_id_token = False
+ self.mock_validator.refresh_id_token.return_value = False
self.mock_validator.get_original_scopes.return_value = [
'hello', 'openid'
]
@@ -80,6 +83,9 @@ class OpenIDRefreshTokenTest(TestCase):
self.assertIn('expires_in', token)
self.assertEqual(token['scope'], 'hello openid')
self.assertNotIn('id_token', token)
+ self.mock_validator.refresh_id_token.assert_called_once_with(
+ self.request
+ )
def test_refresh_token_without_openid_scope(self):
self.request.scope = "hello"