summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIvan Kanakarakis <ivan.kanak@gmail.com>2023-01-31 14:41:07 +0200
committerGitHub <noreply@github.com>2023-01-31 14:41:07 +0200
commitaa0de7cbba84d56b7021fdcfcd3a2bf39a3e249c (patch)
tree451ae979ba74a886ad71e48ecfb9c833bb82807d
parent01f5567facf5f5ef61d4f9e10ed6424b6ed2dae3 (diff)
parent30243a89c43872bb6523478b614ff0a205a01279 (diff)
downloadpysaml2-aa0de7cbba84d56b7021fdcfcd3a2bf39a3e249c.tar.gz
Merge pull request #894 from REANNZ/fix-ed-extensions
Fix: render extensions also for EntityDescriptor and IdPSSODescriptor
-rw-r--r--src/saml2/metadata.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/src/saml2/metadata.py b/src/saml2/metadata.py
index 47823406..4266ca6e 100644
--- a/src/saml2/metadata.py
+++ b/src/saml2/metadata.py
@@ -533,6 +533,17 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None):
idpsso = md.IDPSSODescriptor()
idpsso.protocol_support_enumeration = samlp.NAMESPACE
+ exts = conf.getattr("extensions", "idp")
+ if exts:
+ if idpsso.extensions is None:
+ idpsso.extensions = md.Extensions()
+
+ for key, val in exts.items():
+ _ext = do_extensions(key, val)
+ if _ext:
+ for _e in _ext:
+ idpsso.extensions.add_extension_element(_e)
+
endps = conf.getattr("endpoints", "idp")
if endps:
for (endpoint, instlist) in do_endpoints(endps, ENDPOINTS["idp"]).items():
@@ -578,6 +589,17 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None):
aad = md.AttributeAuthorityDescriptor()
aad.protocol_support_enumeration = samlp.NAMESPACE
+ exts = conf.getattr("extensions", "aa")
+ if exts:
+ if aad.extensions is None:
+ aad.extensions = md.Extensions()
+
+ for key, val in exts.items():
+ _ext = do_extensions(key, val)
+ if _ext:
+ for _e in _ext:
+ aad.extensions.add_extension_element(_e)
+
endps = conf.getattr("endpoints", "aa")
if endps:
@@ -606,6 +628,17 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None):
aqs = md.AuthnAuthorityDescriptor()
aqs.protocol_support_enumeration = samlp.NAMESPACE
+ exts = conf.getattr("extensions", "aa")
+ if exts:
+ if aqs.extensions is None:
+ aqs.extensions = md.Extensions()
+
+ for key, val in exts.items():
+ _ext = do_extensions(key, val)
+ if _ext:
+ for _e in _ext:
+ aqs.extensions.add_extension_element(_e)
+
endps = conf.getattr("endpoints", "aq")
if endps:
@@ -626,6 +659,17 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None):
pdp.protocol_support_enumeration = samlp.NAMESPACE
+ exts = conf.getattr("extensions", "pdp")
+ if exts:
+ if pdp.extensions is None:
+ pdp.extensions = md.Extensions()
+
+ for key, val in exts.items():
+ _ext = do_extensions(key, val)
+ if _ext:
+ for _e in _ext:
+ pdp.extensions.add_extension_element(_e)
+
endps = conf.getattr("endpoints", "pdp")
if endps:
@@ -675,6 +719,17 @@ def entity_descriptor(confd):
if confd.contact_person is not None:
entd.contact_person = do_contact_persons_info(confd.contact_person)
+ exts = confd.extensions
+ if exts:
+ if not entd.extensions:
+ entd.extensions = md.Extensions()
+
+ for key, val in exts.items():
+ _ext = do_extensions(key, val)
+ if _ext:
+ for _e in _ext:
+ entd.extensions.add_extension_element(_e)
+
if confd.entity_attributes:
if not entd.extensions:
entd.extensions = md.Extensions()