diff options
author | Ivan Kanakarakis <ivan.kanak@gmail.com> | 2023-01-31 14:41:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-31 14:41:07 +0200 |
commit | aa0de7cbba84d56b7021fdcfcd3a2bf39a3e249c (patch) | |
tree | 451ae979ba74a886ad71e48ecfb9c833bb82807d | |
parent | 01f5567facf5f5ef61d4f9e10ed6424b6ed2dae3 (diff) | |
parent | 30243a89c43872bb6523478b614ff0a205a01279 (diff) | |
download | pysaml2-aa0de7cbba84d56b7021fdcfcd3a2bf39a3e249c.tar.gz |
Merge pull request #894 from REANNZ/fix-ed-extensions
Fix: render extensions also for EntityDescriptor and IdPSSODescriptor
-rw-r--r-- | src/saml2/metadata.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/src/saml2/metadata.py b/src/saml2/metadata.py index 47823406..4266ca6e 100644 --- a/src/saml2/metadata.py +++ b/src/saml2/metadata.py @@ -533,6 +533,17 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None): idpsso = md.IDPSSODescriptor() idpsso.protocol_support_enumeration = samlp.NAMESPACE + exts = conf.getattr("extensions", "idp") + if exts: + if idpsso.extensions is None: + idpsso.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + idpsso.extensions.add_extension_element(_e) + endps = conf.getattr("endpoints", "idp") if endps: for (endpoint, instlist) in do_endpoints(endps, ENDPOINTS["idp"]).items(): @@ -578,6 +589,17 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None): aad = md.AttributeAuthorityDescriptor() aad.protocol_support_enumeration = samlp.NAMESPACE + exts = conf.getattr("extensions", "aa") + if exts: + if aad.extensions is None: + aad.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + aad.extensions.add_extension_element(_e) + endps = conf.getattr("endpoints", "aa") if endps: @@ -606,6 +628,17 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None): aqs = md.AuthnAuthorityDescriptor() aqs.protocol_support_enumeration = samlp.NAMESPACE + exts = conf.getattr("extensions", "aa") + if exts: + if aqs.extensions is None: + aqs.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + aqs.extensions.add_extension_element(_e) + endps = conf.getattr("endpoints", "aq") if endps: @@ -626,6 +659,17 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None): pdp.protocol_support_enumeration = samlp.NAMESPACE + exts = conf.getattr("extensions", "pdp") + if exts: + if pdp.extensions is None: + pdp.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + pdp.extensions.add_extension_element(_e) + endps = conf.getattr("endpoints", "pdp") if endps: @@ -675,6 +719,17 @@ def entity_descriptor(confd): if confd.contact_person is not None: entd.contact_person = do_contact_persons_info(confd.contact_person) + exts = confd.extensions + if exts: + if not entd.extensions: + entd.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + entd.extensions.add_extension_element(_e) + if confd.entity_attributes: if not entd.extensions: entd.extensions = md.Extensions() |