summaryrefslogtreecommitdiff
path: root/src/saml2/s_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2/s_utils.py')
-rw-r--r--src/saml2/s_utils.py80
1 files changed, 44 insertions, 36 deletions
diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py
index f0d983c9..6e2fd9c2 100644
--- a/src/saml2/s_utils.py
+++ b/src/saml2/s_utils.py
@@ -12,9 +12,9 @@ import zlib
import six
+from saml2 import VERSION
from saml2 import saml
from saml2 import samlp
-from saml2 import VERSION
from saml2.time_util import instant
@@ -89,9 +89,28 @@ EXCEPTION2STATUS = {
Exception: samlp.STATUS_AUTHN_FAILED,
}
-GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
- "gov", "info", "int", "jobs", "mil", "mobi", "museum",
- "name", "net", "org", "pro", "tel", "travel"]
+GENERIC_DOMAINS = [
+ "aero",
+ "asia",
+ "biz",
+ "cat",
+ "com",
+ "coop",
+ "edu",
+ "gov",
+ "info",
+ "int",
+ "jobs",
+ "mil",
+ "mobi",
+ "museum",
+ "name",
+ "net",
+ "org",
+ "pro",
+ "tel",
+ "travel",
+]
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
@@ -104,8 +123,8 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS):
# Split up email address into parts.
try:
- localpart, domainname = emailaddress.rsplit('@', 1)
- host, toplevel = domainname.rsplit('.', 1)
+ localpart, domainname = emailaddress.rsplit("@", 1)
+ host, toplevel = domainname.rsplit(".", 1)
except ValueError:
return False # Address does not have enough parts.
@@ -113,9 +132,9 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS):
if len(toplevel) != 2 and toplevel not in domains:
return False # Not a domain name.
- for i in '-_.%+.':
+ for i in "-_.%+.":
localpart = localpart.replace(i, "")
- for i in '-_.':
+ for i in "-_.":
host = host.replace(i, "")
if localpart.isalnum() and host.isalnum():
@@ -125,7 +144,7 @@ def valid_email(emailaddress, domains=GENERIC_DOMAINS):
def decode_base64_and_inflate(string):
- """ base64 decodes and then inflates according to RFC1951
+ """base64 decodes and then inflates according to RFC1951
:param string: a deflated and encoded string
:return: the string after decoding and inflating
@@ -142,7 +161,7 @@ def deflate_and_base64_encode(string_val):
:return: The deflated and encoded string
"""
if not isinstance(string_val, six.binary_type):
- string_val = string_val.encode('utf-8')
+ string_val = string_val.encode("utf-8")
return base64.b64encode(zlib.compress(string_val)[2:-4])
@@ -165,7 +184,7 @@ def rndbytes(size=16, alphabet=""):
"""
x = rndstr(size, alphabet)
if isinstance(x, six.string_types):
- return x.encode('utf-8')
+ return x.encode("utf-8")
return x
@@ -214,6 +233,7 @@ def identity_attribute(form, attribute, forward_map=None):
# default is name
return attribute.name
+
# ----------------------------------------------------------------------------
@@ -228,24 +248,14 @@ def error_status_factory(info):
try:
exc_context = info.args[0]
- err_ctx = (
- {"status_message_text": exc_context}
- if isinstance(exc_context, str)
- else exc_context
- )
+ err_ctx = {"status_message_text": exc_context} if isinstance(exc_context, str) else exc_context
except IndexError:
err_ctx = {"status_message_text": str(info)}
status_message_text = err_ctx.get("status_message_text")
- status_code_status_code_value = err_ctx.get(
- "status_code_status_code_value", exc_val
- )
-
- status_msg = (
- samlp.StatusMessage(text=status_message_text)
- if status_message_text
- else None
- )
+ status_code_status_code_value = err_ctx.get("status_code_status_code_value", exc_val)
+
+ status_msg = samlp.StatusMessage(text=status_message_text) if status_message_text else None
status = samlp.Status(
status_message=status_msg,
@@ -258,20 +268,18 @@ def error_status_factory(info):
def success_status_factory():
- return samlp.Status(status_code=samlp.StatusCode(
- value=samlp.STATUS_SUCCESS))
+ return samlp.Status(status_code=samlp.StatusCode(value=samlp.STATUS_SUCCESS))
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
return samlp.Status(
status_message=samlp.StatusMessage(text=message),
- status_code=samlp.StatusCode(value=fro,
- status_code=samlp.StatusCode(value=code)))
+ status_code=samlp.StatusCode(value=fro, status_code=samlp.StatusCode(value=code)),
+ )
def assertion_factory(**kwargs):
- assertion = saml.Assertion(version=VERSION, id=sid(),
- issue_instant=instant())
+ assertion = saml.Assertion(version=VERSION, id=sid(), issue_instant=instant())
for key, val in kwargs.items():
setattr(assertion, key, val)
return assertion
@@ -291,6 +299,7 @@ def _attrval(val, typ=""):
return attrval
+
# --- attribute profiles -----
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
@@ -381,14 +390,13 @@ def factory(klass, **kwargs):
def signature(secret, parts):
- """Generates a signature. All strings are assumed to be utf-8
- """
+ """Generates a signature. All strings are assumed to be utf-8"""
if not isinstance(secret, six.binary_type):
- secret = secret.encode('utf-8')
+ secret = secret.encode("utf-8")
newparts = []
for part in parts:
if not isinstance(part, six.binary_type):
- part = part.encode('utf-8')
+ part = part.encode("utf-8")
newparts.append(part)
parts = newparts
csum = hmac.new(secret, digestmod=hashlib.sha1)
@@ -400,7 +408,7 @@ def signature(secret, parts):
def verify_signature(secret, parts):
- """ Checks that the signature is correct """
+ """Checks that the signature is correct"""
if signature(secret, parts[:-1]) == parts[-1]:
return True
else: