# -*- coding: utf-8 -*-
"""
Tests for thread usage in lxml.etree.
"""
import unittest, threading, sys, os.path
this_dir = os.path.dirname(__file__)
if this_dir not in sys.path:
sys.path.insert(0, this_dir) # needed for Py3
from common_imports import etree, HelperTestCase, BytesIO, _bytes
try:
from Queue import Queue
except ImportError:
from queue import Queue # Py3
class ThreadingTestCase(HelperTestCase):
"""Threading tests"""
etree = etree
def _run_thread(self, func):
thread = threading.Thread(target=func)
thread.start()
thread.join()
def test_subtree_copy_thread(self):
tostring = self.etree.tostring
XML = self.etree.XML
xml = _bytes("")
main_root = XML(_bytes(""))
def run_thread():
thread_root = XML(xml)
main_root.append(thread_root[0])
del thread_root
self._run_thread(run_thread)
self.assertEquals(xml, tostring(main_root))
def test_main_xslt_in_thread(self):
XML = self.etree.XML
style = XML(_bytes('''\
'''))
st = etree.XSLT(style)
result = []
def run_thread():
root = XML(_bytes('BC'))
result.append( st(root) )
self._run_thread(run_thread)
self.assertEquals('''\
B
''',
str(result[0]))
def test_thread_xslt(self):
XML = self.etree.XML
tostring = self.etree.tostring
root = XML(_bytes('BC'))
def run_thread():
style = XML(_bytes('''\
'''))
st = etree.XSLT(style)
root.append( st(root).getroot() )
self._run_thread(run_thread)
self.assertEquals(_bytes('BCB'),
tostring(root))
def test_thread_xslt_attr_replace(self):
# this is the only case in XSLT where the result tree can be
# modified in-place
XML = self.etree.XML
tostring = self.etree.tostring
style = self.etree.XSLT(XML(_bytes('''\
xyz
''')))
result = []
def run_thread():
root = XML(_bytes(''))
result.append( style(root).getroot() )
self._run_thread(run_thread)
self.assertEquals(_bytes(''),
tostring(result[0]))
def test_thread_create_xslt(self):
XML = self.etree.XML
tostring = self.etree.tostring
root = XML(_bytes('BC'))
stylesheets = []
def run_thread():
style = XML(_bytes('''\
'''))
stylesheets.append( etree.XSLT(style) )
self._run_thread(run_thread)
st = stylesheets[0]
result = tostring( st(root) )
self.assertEquals(_bytes('
BC
'),
result)
def test_thread_error_log(self):
XML = self.etree.XML
ParseError = self.etree.ParseError
expected_error = [self.etree.ErrorTypes.ERR_TAG_NAME_MISMATCH]
children = "test" * 100
def parse_error_test(thread_no):
tag = "tag%d" % thread_no
xml = "<%s>%s%s>" % (tag, children, tag.upper())
parser = self.etree.XMLParser()
for _ in range(10):
errors = None
try:
XML(xml, parser)
except self.etree.ParseError:
e = sys.exc_info()[1]
errors = e.error_log.filter_types(expected_error)
self.assert_(errors, "Expected error not found")
for error in errors:
self.assert_(
tag in error.message and tag.upper() in error.message,
"%s and %s not found in '%s'" % (
tag, tag.upper(), error.message))
self.etree.clear_error_log()
threads = []
for thread_no in range(1, 10):
t = threading.Thread(target=parse_error_test,
args=(thread_no,))
threads.append(t)
t.start()
parse_error_test(0)
for t in threads:
t.join()
def test_thread_mix(self):
XML = self.etree.XML
Element = self.etree.Element
SubElement = self.etree.SubElement
tostring = self.etree.tostring
xml = _bytes('BC')
root = XML(xml)
fragment = XML(_bytes(""))
result = self.etree.Element("{myns}root", att = "someval")
def run_XML():
thread_root = XML(xml)
result.append(thread_root[0])
result.append(thread_root[-1])
def run_parse():
thread_root = self.etree.parse(BytesIO(xml)).getroot()
result.append(thread_root[0])
result.append(thread_root[-1])
def run_move_main():
result.append(fragment[0])
def run_build():
result.append(
Element("{myns}foo", attrib={'{test}attr':'val'}))
SubElement(result, "{otherns}tasty")
def run_xslt():
style = XML(_bytes('''\
'''))
st = etree.XSLT(style)
result.append( st(root).getroot() )
for test in (run_XML, run_parse, run_move_main, run_xslt, run_build):
tostring(result)
self._run_thread(test)
self.assertEquals(
_bytes('B'
'CBC'
'B'
''
''),
tostring(result))
def strip_first():
root = Element("newroot")
root.append(result[0])
while len(result):
self._run_thread(strip_first)
self.assertEquals(
_bytes(''),
tostring(result))
def test_concurrent_proxies(self):
XML = self.etree.XML
root = XML(_bytes('AB'))
child_count = len(root)
def testrun():
for i in range(10000):
el = root[i%child_count]
del el
threads = [ threading.Thread(target=testrun)
for _ in range(10) ]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
def test_concurrent_class_lookup(self):
XML = self.etree.XML
class TestElement(etree.ElementBase):
pass
class MyLookup(etree.CustomElementClassLookup):
repeat = range(100)
def lookup(self, t, d, ns, name):
count = 0
for i in self.repeat:
# allow other threads to run
count += 1
return TestElement
parser = self.etree.XMLParser()
parser.set_element_class_lookup(MyLookup())
root = XML(_bytes('AB'),
parser)
child_count = len(root)
def testrun():
for i in range(1000):
el = root[i%child_count]
del el
threads = [ threading.Thread(target=testrun)
for _ in range(10) ]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
class ThreadPipelineTestCase(HelperTestCase):
"""Threading tests based on a thread worker pipeline.
"""
etree = etree
item_count = 20
class Worker(threading.Thread):
def __init__(self, in_queue, in_count, **kwargs):
threading.Thread.__init__(self)
self.in_queue = in_queue
self.in_count = in_count
self.out_queue = Queue(in_count)
self.__dict__.update(kwargs)
def run(self):
get, put = self.in_queue.get, self.out_queue.put
handle = self.handle
for _ in range(self.in_count):
put(handle(get()))
class ParseWorker(Worker):
XML = etree.XML
def handle(self, xml):
return self.XML(xml)
class RotateWorker(Worker):
def handle(self, element):
first = element[0]
element[:] = element[1:]
element.append(first)
return element
class ReverseWorker(Worker):
def handle(self, element):
element[:] = element[::-1]
return element
class ParseAndExtendWorker(Worker):
XML = etree.XML
def handle(self, element):
element.extend(self.XML(self.xml))
return element
class SerialiseWorker(Worker):
def handle(self, element):
return etree.tostring(element)
xml = _bytes('''\
''')
def _build_pipeline(self, item_count, *classes, **kwargs):
in_queue = Queue(item_count)
start = last = classes[0](in_queue, item_count, **kwargs)
start.setDaemon(True)
for worker_class in classes[1:]:
last = worker_class(last.out_queue, item_count, **kwargs)
last.setDaemon(True)
last.start()
return (in_queue, start, last)
def test_thread_pipeline_thread_parse(self):
item_count = self.item_count
# build and start the pipeline
in_queue, start, last = self._build_pipeline(
item_count,
self.ParseWorker,
self.RotateWorker,
self.ReverseWorker,
self.ParseAndExtendWorker,
self.SerialiseWorker,
xml = self.xml)
# fill the queue
put = start.in_queue.put
for _ in range(item_count):
put(self.xml)
# start the first thread and thus everything
start.start()
# make sure the last thread has terminated
last.join(60) # time out after 60 seconds
self.assertEquals(item_count, last.out_queue.qsize())
# read the results
get = last.out_queue.get
results = [ get() for _ in range(item_count) ]
comparison = results[0]
for i, result in enumerate(results[1:]):
self.assertEquals(comparison, result)
def test_thread_pipeline_global_parse(self):
item_count = self.item_count
XML = self.etree.XML
# build and start the pipeline
in_queue, start, last = self._build_pipeline(
item_count,
self.RotateWorker,
self.ReverseWorker,
self.ParseAndExtendWorker,
self.SerialiseWorker,
xml = self.xml)
# fill the queue
put = start.in_queue.put
for _ in range(item_count):
put(XML(self.xml))
# start the first thread and thus everything
start.start()
# make sure the last thread has terminated
last.join(60) # time out after 90 seconds
self.assertEquals(item_count, last.out_queue.qsize())
# read the results
get = last.out_queue.get
results = [ get() for _ in range(item_count) ]
comparison = results[0]
for i, result in enumerate(results[1:]):
self.assertEquals(comparison, result)
def test_suite():
suite = unittest.TestSuite()
suite.addTests([unittest.makeSuite(ThreadingTestCase)])
suite.addTests([unittest.makeSuite(ThreadPipelineTestCase)])
return suite
if __name__ == '__main__':
print('to test use test.py %s' % __file__)