diff options
Diffstat (limited to 'src/saml2/s_utils.py')
-rw-r--r-- | src/saml2/s_utils.py | 80 |
1 files changed, 44 insertions, 36 deletions
diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index f0d983c9..6e2fd9c2 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -12,9 +12,9 @@ import zlib import six +from saml2 import VERSION from saml2 import saml from saml2 import samlp -from saml2 import VERSION from saml2.time_util import instant @@ -89,9 +89,28 @@ EXCEPTION2STATUS = { Exception: samlp.STATUS_AUTHN_FAILED, } -GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu", - "gov", "info", "int", "jobs", "mil", "mobi", "museum", - "name", "net", "org", "pro", "tel", "travel"] +GENERIC_DOMAINS = [ + "aero", + "asia", + "biz", + "cat", + "com", + "coop", + "edu", + "gov", + "info", + "int", + "jobs", + "mil", + "mobi", + "museum", + "name", + "net", + "org", + "pro", + "tel", + "travel", +] def valid_email(emailaddress, domains=GENERIC_DOMAINS): @@ -104,8 +123,8 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS): # Split up email address into parts. try: - localpart, domainname = emailaddress.rsplit('@', 1) - host, toplevel = domainname.rsplit('.', 1) + localpart, domainname = emailaddress.rsplit("@", 1) + host, toplevel = domainname.rsplit(".", 1) except ValueError: return False # Address does not have enough parts. @@ -113,9 +132,9 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS): if len(toplevel) != 2 and toplevel not in domains: return False # Not a domain name. - for i in '-_.%+.': + for i in "-_.%+.": localpart = localpart.replace(i, "") - for i in '-_.': + for i in "-_.": host = host.replace(i, "") if localpart.isalnum() and host.isalnum(): @@ -125,7 +144,7 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS): def decode_base64_and_inflate(string): - """ base64 decodes and then inflates according to RFC1951 + """base64 decodes and then inflates according to RFC1951 :param string: a deflated and encoded string :return: the string after decoding and inflating @@ -142,7 +161,7 @@ def deflate_and_base64_encode(string_val): :return: The deflated and encoded string """ if not isinstance(string_val, six.binary_type): - string_val = string_val.encode('utf-8') + string_val = string_val.encode("utf-8") return base64.b64encode(zlib.compress(string_val)[2:-4]) @@ -165,7 +184,7 @@ def rndbytes(size=16, alphabet=""): """ x = rndstr(size, alphabet) if isinstance(x, six.string_types): - return x.encode('utf-8') + return x.encode("utf-8") return x @@ -214,6 +233,7 @@ def identity_attribute(form, attribute, forward_map=None): # default is name return attribute.name + # ---------------------------------------------------------------------------- @@ -228,24 +248,14 @@ def error_status_factory(info): try: exc_context = info.args[0] - err_ctx = ( - {"status_message_text": exc_context} - if isinstance(exc_context, str) - else exc_context - ) + err_ctx = {"status_message_text": exc_context} if isinstance(exc_context, str) else exc_context except IndexError: err_ctx = {"status_message_text": str(info)} status_message_text = err_ctx.get("status_message_text") - status_code_status_code_value = err_ctx.get( - "status_code_status_code_value", exc_val - ) - - status_msg = ( - samlp.StatusMessage(text=status_message_text) - if status_message_text - else None - ) + status_code_status_code_value = err_ctx.get("status_code_status_code_value", exc_val) + + status_msg = samlp.StatusMessage(text=status_message_text) if status_message_text else None status = samlp.Status( status_message=status_msg, @@ -258,20 +268,18 @@ def error_status_factory(info): def success_status_factory(): - return samlp.Status(status_code=samlp.StatusCode( - value=samlp.STATUS_SUCCESS)) + return samlp.Status(status_code=samlp.StatusCode(value=samlp.STATUS_SUCCESS)) def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER): return samlp.Status( status_message=samlp.StatusMessage(text=message), - status_code=samlp.StatusCode(value=fro, - status_code=samlp.StatusCode(value=code))) + status_code=samlp.StatusCode(value=fro, status_code=samlp.StatusCode(value=code)), + ) def assertion_factory(**kwargs): - assertion = saml.Assertion(version=VERSION, id=sid(), - issue_instant=instant()) + assertion = saml.Assertion(version=VERSION, id=sid(), issue_instant=instant()) for key, val in kwargs.items(): setattr(assertion, key, val) return assertion @@ -291,6 +299,7 @@ def _attrval(val, typ=""): return attrval + # --- attribute profiles ----- # xmlns:xs="http://www.w3.org/2001/XMLSchema" @@ -381,14 +390,13 @@ def factory(klass, **kwargs): def signature(secret, parts): - """Generates a signature. All strings are assumed to be utf-8 - """ + """Generates a signature. All strings are assumed to be utf-8""" if not isinstance(secret, six.binary_type): - secret = secret.encode('utf-8') + secret = secret.encode("utf-8") newparts = [] for part in parts: if not isinstance(part, six.binary_type): - part = part.encode('utf-8') + part = part.encode("utf-8") newparts.append(part) parts = newparts csum = hmac.new(secret, digestmod=hashlib.sha1) @@ -400,7 +408,7 @@ def signature(secret, parts): def verify_signature(secret, parts): - """ Checks that the signature is correct """ + """Checks that the signature is correct""" if signature(secret, parts[:-1]) == parts[-1]: return True else: |