summaryrefslogtreecommitdiff
path: root/testlib.py
diff options
context:
space:
mode:
authorAdrien Di Mascio <adim@logilab.fr>2006-05-24 17:18:12 +0200
committerAdrien Di Mascio <adim@logilab.fr>2006-05-24 17:18:12 +0200
commit7b31f9fb0eba21e28e53c54d1dd73745ab5f5bea (patch)
tree3bbe3cb38ab6709090128de6cd644ff1db7795a6 /testlib.py
parent96935250413d8c1c0181c1dd1622766f55eae4a7 (diff)
downloadlogilab-common-7b31f9fb0eba21e28e53c54d1dd73745ab5f5bea.tar.gz
added a more flexible test loader in testlib
Diffstat (limited to 'testlib.py')
-rw-r--r--testlib.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/testlib.py b/testlib.py
index acbd943..9b1bb28 100644
--- a/testlib.py
+++ b/testlib.py
@@ -32,6 +32,7 @@ import getopt
import traceback
import unittest
import difflib
+import types
from warnings import warn
from compiler.consts import CO_GENERATOR
@@ -360,6 +361,102 @@ class starargs(tuple):
return tuple.__new__(cls, args)
+
+class NonStrictTestLoader(unittest.TestLoader):
+ """
+ overrides default testloader to be able to omit classname when
+ specifying tests to run on command line. For example, if the file
+ test_foo.py contains ::
+
+ class FooTC(TestCase):
+ def test_foo1(self): # ...
+ def test_foo2(self): # ...
+ def test_bar1(self): # ...
+
+ class BarTC(TestCase):
+ def test_bar2(self): # ...
+
+ python test_foo.py will run the 3 tests in FooTC
+ python test_foo.py FooTC will run the 3 tests in FooTC
+ python test_foo.py test_foo will run test_foo1 and test_foo2
+ python test_foo.py test_foo1 will run test_foo1
+ python test_foo.py test_bar will run FooTC.test_bar1 and BarTC.test_bar2
+ """
+ def loadTestsFromNames(self, names, module=None):
+ suites = []
+ for name in names:
+ suites.extend(self.loadTestsFromName(name, module))
+ return self.suiteClass(suites)
+
+
+ def _collect_tests(self, module):
+ tests = {}
+ for obj in vars(module).values():
+ if type(obj) in (types.ClassType, type) and \
+ issubclass(obj, unittest.TestCase):
+ classname = obj.__name__
+ methodnames = []
+ # obj is a TestCase class
+ for attrname in dir(obj):
+ if attrname.startswith(self.testMethodPrefix):
+ attr = getattr(obj, attrname)
+ if callable(attr):
+ methodnames.append(attrname)
+ # keep track of class (obj) for convenience
+ tests[classname] = (obj, methodnames)
+ return tests
+
+ def loadTestsFromName(self, name, module=None):
+ parts = name.split('.')
+ if module is None or len(parts) > 2:
+ # let the base class do its job here
+ return [unittest.TestLoader.loadTestsFromName(self, name)]
+ tests = self._collect_tests(module)
+ # import pprint
+ # pprint.pprint(tests)
+ collected = []
+ if len(parts) == 1:
+ pattern = parts[0]
+ if pattern in tests:
+ # case python unittest_foo.py MyTestTC
+ klass, methodnames = tests[pattern]
+ for methodname in methodnames:
+ collected = [klass(methodname) for methodname in methodnames]
+ else:
+ # case python unittest_foo.py something
+ for klass, methodnames in tests.values():
+ collected += [klass(methodname) for methodname in methodnames
+ if self._test_should_be_collected(methodname, pattern)]
+ elif len(parts) == 2:
+ # case "MyClass.test_1"
+ classname, pattern = parts
+ klass, methodnames = tests.get(classname, (None, []))
+ for methodname in methodnames:
+ collected = [klass(methodname) for methodname in methodnames
+ if self._test_should_be_collected(methodname, pattern)]
+ return collected
+
+ def _test_should_be_collected(self, methodname, pattern):
+ """returns True if <methodname> matches <pattern>
+ >>> self._test_should_be_collected('test_foobar', 'foo')
+ True
+ >>> self._test_should_be_collected('testfoobar', 'foo')
+ True
+ >>> self._test_should_be_collected('test_foobar', 'test_foo')
+ True
+ >>> self._test_should_be_collected('test_foobar', 'testfoo')
+ False
+ """
+ prefix = self.testMethodPrefix
+ # case where testname="test_some" and methodname="test_something"
+ if pattern.startswith(prefix):
+ return methodname.startswith(pattern)
+ # case where pattern="foo" and methodname="testfoo" or "test_foo"
+ if methodname.startswith(prefix + pattern) or \
+ methodname.startswith(prefix + '_' + pattern):
+ return True
+ return False
+
class SkipAwareTestProgram(unittest.TestProgram):
# XXX: don't try to stay close to unittest.py, use optparse
USAGE = """\
@@ -379,6 +476,11 @@ Examples:
%(progName)s MyTestCase - run all 'test*' test methods
in MyTestCase
"""
+ def __init__(self, module='__main__'):
+ unittest.TestProgram.__init__(self, module=module,
+ testLoader=NonStrictTestLoader())
+
+
def parseArgs(self, argv):
self.pdbmode = False
self.exitfirst = False
@@ -409,6 +511,7 @@ Examples:
self.usageExit(msg)
+
def runTests(self):
self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity,
exitfirst=self.exitfirst)