diff options
author | Adrien Di Mascio <adim@logilab.fr> | 2006-05-24 17:18:12 +0200 |
---|---|---|
committer | Adrien Di Mascio <adim@logilab.fr> | 2006-05-24 17:18:12 +0200 |
commit | 7b31f9fb0eba21e28e53c54d1dd73745ab5f5bea (patch) | |
tree | 3bbe3cb38ab6709090128de6cd644ff1db7795a6 /testlib.py | |
parent | 96935250413d8c1c0181c1dd1622766f55eae4a7 (diff) | |
download | logilab-common-7b31f9fb0eba21e28e53c54d1dd73745ab5f5bea.tar.gz |
added a more flexible test loader in testlib
Diffstat (limited to 'testlib.py')
-rw-r--r-- | testlib.py | 103 |
1 files changed, 103 insertions, 0 deletions
@@ -32,6 +32,7 @@ import getopt import traceback import unittest import difflib +import types from warnings import warn from compiler.consts import CO_GENERATOR @@ -360,6 +361,102 @@ class starargs(tuple): return tuple.__new__(cls, args) + +class NonStrictTestLoader(unittest.TestLoader): + """ + overrides default testloader to be able to omit classname when + specifying tests to run on command line. For example, if the file + test_foo.py contains :: + + class FooTC(TestCase): + def test_foo1(self): # ... + def test_foo2(self): # ... + def test_bar1(self): # ... + + class BarTC(TestCase): + def test_bar2(self): # ... + + python test_foo.py will run the 3 tests in FooTC + python test_foo.py FooTC will run the 3 tests in FooTC + python test_foo.py test_foo will run test_foo1 and test_foo2 + python test_foo.py test_foo1 will run test_foo1 + python test_foo.py test_bar will run FooTC.test_bar1 and BarTC.test_bar2 + """ + def loadTestsFromNames(self, names, module=None): + suites = [] + for name in names: + suites.extend(self.loadTestsFromName(name, module)) + return self.suiteClass(suites) + + + def _collect_tests(self, module): + tests = {} + for obj in vars(module).values(): + if type(obj) in (types.ClassType, type) and \ + issubclass(obj, unittest.TestCase): + classname = obj.__name__ + methodnames = [] + # obj is a TestCase class + for attrname in dir(obj): + if attrname.startswith(self.testMethodPrefix): + attr = getattr(obj, attrname) + if callable(attr): + methodnames.append(attrname) + # keep track of class (obj) for convenience + tests[classname] = (obj, methodnames) + return tests + + def loadTestsFromName(self, name, module=None): + parts = name.split('.') + if module is None or len(parts) > 2: + # let the base class do its job here + return [unittest.TestLoader.loadTestsFromName(self, name)] + tests = self._collect_tests(module) + # import pprint + # pprint.pprint(tests) + collected = [] + if len(parts) == 1: + pattern = parts[0] + if pattern in tests: + # case python unittest_foo.py MyTestTC + klass, methodnames = tests[pattern] + for methodname in methodnames: + collected = [klass(methodname) for methodname in methodnames] + else: + # case python unittest_foo.py something + for klass, methodnames in tests.values(): + collected += [klass(methodname) for methodname in methodnames + if self._test_should_be_collected(methodname, pattern)] + elif len(parts) == 2: + # case "MyClass.test_1" + classname, pattern = parts + klass, methodnames = tests.get(classname, (None, [])) + for methodname in methodnames: + collected = [klass(methodname) for methodname in methodnames + if self._test_should_be_collected(methodname, pattern)] + return collected + + def _test_should_be_collected(self, methodname, pattern): + """returns True if <methodname> matches <pattern> + >>> self._test_should_be_collected('test_foobar', 'foo') + True + >>> self._test_should_be_collected('testfoobar', 'foo') + True + >>> self._test_should_be_collected('test_foobar', 'test_foo') + True + >>> self._test_should_be_collected('test_foobar', 'testfoo') + False + """ + prefix = self.testMethodPrefix + # case where testname="test_some" and methodname="test_something" + if pattern.startswith(prefix): + return methodname.startswith(pattern) + # case where pattern="foo" and methodname="testfoo" or "test_foo" + if methodname.startswith(prefix + pattern) or \ + methodname.startswith(prefix + '_' + pattern): + return True + return False + class SkipAwareTestProgram(unittest.TestProgram): # XXX: don't try to stay close to unittest.py, use optparse USAGE = """\ @@ -379,6 +476,11 @@ Examples: %(progName)s MyTestCase - run all 'test*' test methods in MyTestCase """ + def __init__(self, module='__main__'): + unittest.TestProgram.__init__(self, module=module, + testLoader=NonStrictTestLoader()) + + def parseArgs(self, argv): self.pdbmode = False self.exitfirst = False @@ -409,6 +511,7 @@ Examples: self.usageExit(msg) + def runTests(self): self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity, exitfirst=self.exitfirst) |