from __future__ import print_function import io import os import sys import unittest import warnings from xml.etree import ElementTree as orig_elementtree from xml.sax.saxutils import XMLGenerator from xml.sax import SAXParseException from pyexpat import ExpatError from defusedxml import ElementTree, minidom, pulldom, sax, xmlrpc, expatreader from defusedxml import defuse_stdlib from defusedxml import ( DTDForbidden, EntitiesForbidden, ExternalReferenceForbidden, NotSupportedError, ) if sys.version_info < (3, 7): warnings.filterwarnings("once", category=DeprecationWarning) with warnings.catch_warnings(record=True) as cetree_warnings: from defusedxml import cElementTree try: import gzip except ImportError: gzip = None try: with warnings.catch_warnings(record=True) as lxml_warnings: from defusedxml import lxml from lxml.etree import XMLSyntaxError LXML3 = lxml.LXML3 except ImportError: lxml = None XMLSyntaxError = None LXML3 = False lxml_warnings = None warnings.filterwarnings("error", category=DeprecationWarning, module=r"defusedxml\..*") HERE = os.path.dirname(os.path.abspath(__file__)) # prevent web access # based on Debian's rules, Port 9 is discard os.environ["http_proxy"] = "http://127.0.9.1:9" os.environ["https_proxy"] = os.environ["http_proxy"] os.environ["ftp_proxy"] = os.environ["http_proxy"] class DefusedTestCase(unittest.TestCase): content_binary = False xml_dtd = os.path.join(HERE, "xmltestdata", "dtd.xml") xml_external = os.path.join(HERE, "xmltestdata", "external.xml") xml_external_file = os.path.join(HERE, "xmltestdata", "external_file.xml") xml_quadratic = os.path.join(HERE, "xmltestdata", "quadratic.xml") xml_simple = os.path.join(HERE, "xmltestdata", "simple.xml") xml_simple_ns = os.path.join(HERE, "xmltestdata", "simple-ns.xml") xml_bomb = os.path.join(HERE, "xmltestdata", "xmlbomb.xml") xml_bomb2 = os.path.join(HERE, "xmltestdata", "xmlbomb2.xml") xml_cyclic = os.path.join(HERE, "xmltestdata", "cyclic.xml") def get_content(self, xmlfile): mode = "rb" if self.content_binary else "r" with io.open(xmlfile, mode) as f: data = f.read() return data class BaseTests(DefusedTestCase): module = None dtd_external_ref = False external_ref_exception = ExternalReferenceForbidden cyclic_error = None iterparse = None def test_simple_parse(self): self.parse(self.xml_simple) self.parseString(self.get_content(self.xml_simple)) if self.iterparse: self.iterparse(self.xml_simple) def test_simple_parse_ns(self): self.parse(self.xml_simple_ns) self.parseString(self.get_content(self.xml_simple_ns)) if self.iterparse: self.iterparse(self.xml_simple_ns) def test_entities_forbidden(self): self.assertRaises(EntitiesForbidden, self.parse, self.xml_bomb) self.assertRaises(EntitiesForbidden, self.parse, self.xml_quadratic) self.assertRaises(EntitiesForbidden, self.parse, self.xml_external) self.assertRaises(EntitiesForbidden, self.parseString, self.get_content(self.xml_bomb)) self.assertRaises( EntitiesForbidden, self.parseString, self.get_content(self.xml_quadratic) ) self.assertRaises( EntitiesForbidden, self.parseString, self.get_content(self.xml_external) ) if self.iterparse: self.assertRaises(EntitiesForbidden, self.iterparse, self.xml_bomb) self.assertRaises(EntitiesForbidden, self.iterparse, self.xml_quadratic) self.assertRaises(EntitiesForbidden, self.iterparse, self.xml_external) def test_entity_cycle(self): self.assertRaises(self.cyclic_error, self.parse, self.xml_cyclic, forbid_entities=False) def test_dtd_forbidden(self): self.assertRaises(DTDForbidden, self.parse, self.xml_bomb, forbid_dtd=True) self.assertRaises(DTDForbidden, self.parse, self.xml_quadratic, forbid_dtd=True) self.assertRaises(DTDForbidden, self.parse, self.xml_external, forbid_dtd=True) self.assertRaises(DTDForbidden, self.parse, self.xml_dtd, forbid_dtd=True) self.assertRaises( DTDForbidden, self.parseString, self.get_content(self.xml_bomb), forbid_dtd=True ) self.assertRaises( DTDForbidden, self.parseString, self.get_content(self.xml_quadratic), forbid_dtd=True ) self.assertRaises( DTDForbidden, self.parseString, self.get_content(self.xml_external), forbid_dtd=True ) self.assertRaises( DTDForbidden, self.parseString, self.get_content(self.xml_dtd), forbid_dtd=True ) if self.iterparse: self.assertRaises(DTDForbidden, self.iterparse, self.xml_bomb, forbid_dtd=True) self.assertRaises(DTDForbidden, self.iterparse, self.xml_quadratic, forbid_dtd=True) self.assertRaises(DTDForbidden, self.iterparse, self.xml_external, forbid_dtd=True) self.assertRaises(DTDForbidden, self.iterparse, self.xml_dtd, forbid_dtd=True) def test_dtd_with_external_ref(self): if self.dtd_external_ref: self.assertRaises(self.external_ref_exception, self.parse, self.xml_dtd) else: self.parse(self.xml_dtd) def test_external_ref(self): self.assertRaises( self.external_ref_exception, self.parse, self.xml_external, forbid_entities=False ) def test_external_file_ref(self): content = self.get_content(self.xml_external_file) if isinstance(content, bytes): here = HERE.encode(sys.getfilesystemencoding()) content = content.replace(b"/PATH/TO", here) else: content = content.replace("/PATH/TO", HERE) self.assertRaises( self.external_ref_exception, self.parseString, content, forbid_entities=False ) def test_allow_expansion(self): self.parse(self.xml_bomb2, forbid_entities=False) self.parseString(self.get_content(self.xml_bomb2), forbid_entities=False) class TestDefusedElementTree(BaseTests): module = ElementTree # etree doesn't do external ref lookup # external_ref_exception = ElementTree.ParseError cyclic_error = ElementTree.ParseError def parse(self, xmlfile, **kwargs): tree = self.module.parse(xmlfile, **kwargs) return self.module.tostring(tree.getroot()) def parseString(self, xmlstring, **kwargs): tree = self.module.fromstring(xmlstring, **kwargs) return self.module.tostring(tree) def parseStringList(self, sequence, **kwargs): tree = self.module.fromstringlist(sequence, **kwargs) return self.module.tostring(tree) def iterparse(self, source, **kwargs): return list(self.module.iterparse(source, **kwargs)) def test_html_arg(self): with self.assertRaises(DeprecationWarning): ElementTree.XMLParse(html=0) with self.assertRaises(TypeError): ElementTree.XMLParse(html=1) def test_aliases(self): parser = self.module.DefusedXMLParser assert self.module.XMLTreeBuilder is parser assert self.module.XMLParser is parser assert self.module.XMLParse is parser def test_fromstringlist(self): seq = ["", '', '', ""] tree = self.module.fromstringlist(seq) result = self.module.tostring(tree) self.assertEqual(result, "".join(seq).encode("utf-8")) def test_import_order(self): from xml.etree import ElementTree as second_elementtree self.assertIs(orig_elementtree, second_elementtree) def test_orig_parseerror(self): # https://github.com/tiran/defusedxml/issues/63 self.assertIs(self.module.ParseError, orig_elementtree.ParseError) try: self.parseString("invalid") except Exception as e: self.assertIsInstance(e, orig_elementtree.ParseError) self.assertIsInstance(e, self.module.ParseError) class TestDefusedcElementTree(TestDefusedElementTree): module = cElementTree def test_celementtree_warnings(self): self.assertTrue(cetree_warnings) self.assertEqual(cetree_warnings[0].category, DeprecationWarning) self.assertIn("tests.py", cetree_warnings[0].filename) class TestDefusedMinidom(BaseTests): module = minidom cyclic_error = ExpatError iterparse = None def parse(self, xmlfile, **kwargs): doc = self.module.parse(xmlfile, **kwargs) return doc.toxml() def parseString(self, xmlstring, **kwargs): doc = self.module.parseString(xmlstring, **kwargs) return doc.toxml() class TestDefusedMinidomWithParser(TestDefusedMinidom): cyclic_error = SAXParseException dtd_external_ref = True def parse(self, xmlfile, **kwargs): doc = self.module.parse(xmlfile, parser=expatreader.create_parser(**kwargs), **kwargs) return doc.toxml() def parseString(self, xmlstring, **kwargs): doc = self.module.parseString( xmlstring, parser=expatreader.create_parser(**kwargs), **kwargs ) return doc.toxml() class TestDefusedPulldom(BaseTests): module = pulldom cyclic_error = SAXParseException dtd_external_ref = True def parse(self, xmlfile, **kwargs): events = self.module.parse(xmlfile, **kwargs) return list(events) def parseString(self, xmlstring, **kwargs): events = self.module.parseString(xmlstring, **kwargs) return list(events) class TestDefusedSax(BaseTests): module = sax cyclic_error = SAXParseException content_binary = True dtd_external_ref = True def parse(self, xmlfile, **kwargs): result = io.StringIO() handler = XMLGenerator(result) self.module.parse(xmlfile, handler, **kwargs) return result.getvalue() def parseString(self, xmlstring, **kwargs): result = io.StringIO() handler = XMLGenerator(result) self.module.parseString(xmlstring, handler, **kwargs) return result.getvalue() def test_exceptions(self): with self.assertRaises(EntitiesForbidden) as ctx: self.parse(self.xml_bomb) msg = "EntitiesForbidden(name='a', system_id=None, public_id=None)" self.assertEqual(str(ctx.exception), msg) self.assertEqual(repr(ctx.exception), msg) with self.assertRaises(ExternalReferenceForbidden) as ctx: self.parse(self.xml_external, forbid_entities=False) msg = ( "ExternalReferenceForbidden" "(system_id='http://www.w3schools.com/xml/note.xml', public_id=None)" ) self.assertEqual(str(ctx.exception), msg) self.assertEqual(repr(ctx.exception), msg) with self.assertRaises(DTDForbidden) as ctx: self.parse(self.xml_bomb, forbid_dtd=True) msg = "DTDForbidden(name='xmlbomb', system_id=None, public_id=None)" self.assertEqual(str(ctx.exception), msg) self.assertEqual(repr(ctx.exception), msg) class TestDefusedLxml(BaseTests): module = lxml cyclic_error = XMLSyntaxError content_binary = True def parse(self, xmlfile, **kwargs): try: tree = self.module.parse(xmlfile, **kwargs) except XMLSyntaxError: self.skipTest("lxml detects entityt reference loop") return self.module.tostring(tree) def parseString(self, xmlstring, **kwargs): try: tree = self.module.fromstring(xmlstring, **kwargs) except XMLSyntaxError: self.skipTest("lxml detects entityt reference loop") return self.module.tostring(tree) if not LXML3: def test_entities_forbidden(self): self.assertRaises(NotSupportedError, self.parse, self.xml_bomb) def test_dtd_with_external_ref(self): self.assertRaises(NotSupportedError, self.parse, self.xml_dtd) def test_external_ref(self): pass def test_external_file_ref(self): pass def test_restricted_element1(self): try: tree = self.module.parse(self.xml_bomb, forbid_dtd=False, forbid_entities=False) except XMLSyntaxError: self.skipTest("lxml detects entityt reference loop") root = tree.getroot() self.assertEqual(root.text, None) self.assertEqual(list(root), []) self.assertEqual(root.getchildren(), []) self.assertEqual(list(root.iter()), [root]) self.assertEqual(list(root.iterchildren()), []) self.assertEqual(list(root.iterdescendants()), []) self.assertEqual(list(root.itersiblings()), []) self.assertEqual(list(root.getiterator()), [root]) self.assertEqual(root.getnext(), None) def test_restricted_element2(self): try: tree = self.module.parse(self.xml_bomb2, forbid_dtd=False, forbid_entities=False) except XMLSyntaxError: self.skipTest("lxml detects entityt reference loop") root = tree.getroot() bomb, tag = root self.assertEqual(root.text, "text") self.assertEqual(list(root), [bomb, tag]) self.assertEqual(root.getchildren(), [bomb, tag]) self.assertEqual(list(root.iter()), [root, bomb, tag]) self.assertEqual(list(root.iterchildren()), [bomb, tag]) self.assertEqual(list(root.iterdescendants()), [bomb, tag]) self.assertEqual(list(root.itersiblings()), []) self.assertEqual(list(root.getiterator()), [root, bomb, tag]) self.assertEqual(root.getnext(), None) self.assertEqual(root.getprevious(), None) self.assertEqual(list(bomb.itersiblings()), [tag]) self.assertEqual(bomb.getnext(), tag) self.assertEqual(bomb.getprevious(), None) self.assertEqual(tag.getnext(), None) self.assertEqual(tag.getprevious(), bomb) def test_xpath_injection(self): # show XPath injection vulnerability xml = """""" expr = "one' or @id='two" root = lxml.fromstring(xml) # insecure way xp = "tag[@id='%s']" % expr elements = root.xpath(xp) self.assertEqual(len(elements), 2) self.assertEqual(elements, list(root)) # proper and safe way xp = "tag[@id=$idname]" elements = root.xpath(xp, idname=expr) self.assertEqual(len(elements), 0) self.assertEqual(elements, []) elements = root.xpath(xp, idname="one") self.assertEqual(len(elements), 1) self.assertEqual(elements, list(root)[:1]) def test_lxml_warnings(self): self.assertTrue(lxml_warnings) self.assertEqual(lxml_warnings[0].category, DeprecationWarning) self.assertIn("tests.py", lxml_warnings[0].filename) class XmlRpcTarget(object): def __init__(self): self._data = [] def __str__(self): return "".join(self._data) def xml(self, encoding, standalone): pass def start(self, tag, attrs): self._data.append("<%s>" % tag) def data(self, text): self._data.append(text) def end(self, tag): self._data.append("" % tag) class TestXmlRpc(DefusedTestCase): module = xmlrpc def parse(self, xmlfile, **kwargs): target = XmlRpcTarget() parser = self.module.DefusedExpatParser(target, **kwargs) data = self.get_content(xmlfile) parser.feed(data) parser.close() return target def parse_unpatched(self, xmlfile): target = XmlRpcTarget() parser = self.module.ExpatParser(target) data = self.get_content(xmlfile) parser.feed(data) parser.close() return target def test_xmlrpc(self): self.assertRaises(EntitiesForbidden, self.parse, self.xml_bomb) self.assertRaises(EntitiesForbidden, self.parse, self.xml_quadratic) self.parse(self.xml_dtd) self.assertRaises(DTDForbidden, self.parse, self.xml_dtd, forbid_dtd=True) # def test_xmlrpc_unpatched(self): # for fname in (self.xml_external, self.xml_dtd): # print(self.parse_unpatched(fname)) def test_monkeypatch(self): try: xmlrpc.monkey_patch() finally: xmlrpc.unmonkey_patch() class TestDefusedGzip(DefusedTestCase): def get_gzipped(self, length): f = io.BytesIO() gzf = gzip.GzipFile(mode="wb", fileobj=f) gzf.write(b"d" * length) gzf.close() f.seek(0) return f def decode_response(self, response, limit=None, readlength=1024): dec = xmlrpc.DefusedGzipDecodedResponse(response, limit) acc = [] while True: data = dec.read(readlength) if not data: break acc.append(data) return b"".join(acc) def test_defused_gzip_decode(self): data = self.get_gzipped(4096).getvalue() result = xmlrpc.defused_gzip_decode(data) self.assertEqual(result, b"d" * 4096) result = xmlrpc.defused_gzip_decode(data, -1) self.assertEqual(result, b"d" * 4096) result = xmlrpc.defused_gzip_decode(data, 4096) self.assertEqual(result, b"d" * 4096) with self.assertRaises(ValueError): result = xmlrpc.defused_gzip_decode(data, 4095) with self.assertRaises(ValueError): result = xmlrpc.defused_gzip_decode(data, 0) def test_defused_gzip_response(self): clen = len(self.get_gzipped(4096).getvalue()) response = self.get_gzipped(4096) data = self.decode_response(response) self.assertEqual(data, b"d" * 4096) with self.assertRaises(ValueError): response = self.get_gzipped(4096) xmlrpc.DefusedGzipDecodedResponse(response, clen - 1) with self.assertRaises(ValueError): response = self.get_gzipped(4096) self.decode_response(response, 4095) with self.assertRaises(ValueError): response = self.get_gzipped(4096) self.decode_response(response, 4095, 8192) def test_main(): suite = unittest.TestSuite() suite.addTests(unittest.makeSuite(TestDefusedcElementTree)) suite.addTests(unittest.makeSuite(TestDefusedElementTree)) suite.addTests(unittest.makeSuite(TestDefusedMinidom)) suite.addTests(unittest.makeSuite(TestDefusedMinidomWithParser)) suite.addTests(unittest.makeSuite(TestDefusedPulldom)) suite.addTests(unittest.makeSuite(TestDefusedSax)) suite.addTests(unittest.makeSuite(TestXmlRpc)) if lxml is not None: suite.addTests(unittest.makeSuite(TestDefusedLxml)) if gzip is not None: suite.addTests(unittest.makeSuite(TestDefusedGzip)) return suite if __name__ == "__main__": suite = test_main() result = unittest.TextTestRunner(verbosity=1).run(suite) # TODO: test that it actually works defuse_stdlib() sys.exit(not result.wasSuccessful())