diff options
Diffstat (limited to 'Lib/unittest.py')
| -rw-r--r-- | Lib/unittest.py | 98 | 
1 files changed, 58 insertions, 40 deletions
| diff --git a/Lib/unittest.py b/Lib/unittest.py index 043b9a848a..f44769e926 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from    http://pyunit.sourceforge.net/ -Copyright (c) 1999, 2000, 2001 Steve Purcell +Copyright (c) 1999-2003 Steve Purcell  This module is free software, and you may redistribute it and/or modify  it under the same terms as Python itself, so long as this copyright message  and disclaimer are retained in their original form. @@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.  __author__ = "Steve Purcell"  __email__ = "stephen_purcell at yahoo dot com" -__version__ = "#Revision: 1.46 $"[11:-2] +__version__ = "#Revision: 1.56 $"[11:-2]  import time  import sys  import traceback -import string  import os  import types @@ -61,11 +60,27 @@ import types  __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',             'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader'] -# Expose obsolete functions for backwards compatability +# Expose obsolete functions for backwards compatibility  __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])  ############################################################################## +# Backward compatibility +############################################################################## +if sys.version_info[:2] < (2, 2): +    False, True = 0, 1 +    def isinstance(obj, clsinfo): +        import __builtin__ +        if type(clsinfo) in (types.TupleType, types.ListType): +            for cls in clsinfo: +                if cls is type: cls = types.ClassType +                if __builtin__.isinstance(obj, cls): +                    return 1 +            return 0 +        else: return __builtin__.isinstance(obj, clsinfo) + + +##############################################################################  # Test framework core  ############################################################################## @@ -121,11 +136,11 @@ class TestResult:      def stop(self):          "Indicates that the tests should be aborted" -        self.shouldStop = 1 +        self.shouldStop = True      def _exc_info_to_string(self, err):          """Converts a sys.exc_info()-style tuple of values into a string.""" -        return string.join(traceback.format_exception(*err), '') +        return ''.join(traceback.format_exception(*err))      def __repr__(self):          return "<%s run=%i errors=%i failures=%i>" % \ @@ -196,7 +211,7 @@ class TestCase:          the specified test method's docstring.          """          doc = self.__testMethodDoc -        return doc and string.strip(string.split(doc, "\n")[0]) or None +        return doc and doc.split("\n")[0].strip() or None      def id(self):          return "%s.%s" % (_strclass(self.__class__), self.__testMethodName) @@ -209,9 +224,6 @@ class TestCase:                 (_strclass(self.__class__), self.__testMethodName)      def run(self, result=None): -        return self(result) - -    def __call__(self, result=None):          if result is None: result = self.defaultTestResult()          result.startTest(self)          testMethod = getattr(self, self.__testMethodName) @@ -224,10 +236,10 @@ class TestCase:                  result.addError(self, self.__exc_info())                  return -            ok = 0 +            ok = False              try:                  testMethod() -                ok = 1 +                ok = True              except self.failureException:                  result.addFailure(self, self.__exc_info())              except KeyboardInterrupt: @@ -241,11 +253,13 @@ class TestCase:                  raise              except:                  result.addError(self, self.__exc_info()) -                ok = 0 +                ok = False              if ok: result.addSuccess(self)          finally:              result.stopTest(self) +    __call__ = run +      def debug(self):          """Run the test without collecting errors in a TestResult"""          self.setUp() @@ -292,7 +306,7 @@ class TestCase:          else:              if hasattr(excClass,'__name__'): excName = excClass.__name__              else: excName = str(excClass) -            raise self.failureException, excName +            raise self.failureException, "%s not raised" % excName      def failUnlessEqual(self, first, second, msg=None):          """Fail if the two objects are unequal as determined by the '==' @@ -334,6 +348,8 @@ class TestCase:              raise self.failureException, \                    (msg or '%s == %s within %s places' % (`first`, `second`, `places`)) +    # Synonyms for assertion methods +      assertEqual = assertEquals = failUnlessEqual      assertNotEqual = assertNotEquals = failIfEqual @@ -344,7 +360,9 @@ class TestCase:      assertRaises = failUnlessRaises -    assert_ = failUnless +    assert_ = assertTrue = failUnless + +    assertFalse = failIf @@ -369,7 +387,7 @@ class TestSuite:      def countTestCases(self):          cases = 0          for test in self._tests: -            cases = cases + test.countTestCases() +            cases += test.countTestCases()          return cases      def addTest(self, test): @@ -434,7 +452,7 @@ class FunctionTestCase(TestCase):      def shortDescription(self):          if self.__description is not None: return self.__description          doc = self.__testFunc.__doc__ -        return doc and string.strip(string.split(doc, "\n")[0]) or None +        return doc and doc.split("\n")[0].strip() or None @@ -452,8 +470,10 @@ class TestLoader:      def loadTestsFromTestCase(self, testCaseClass):          """Return a suite of all tests cases contained in testCaseClass""" -        return self.suiteClass(map(testCaseClass, -                                   self.getTestCaseNames(testCaseClass))) +        testCaseNames = self.getTestCaseNames(testCaseClass) +        if not testCaseNames and hasattr(testCaseClass, 'runTest'): +            testCaseNames = ['runTest'] +        return self.suiteClass(map(testCaseClass, testCaseNames))      def loadTestsFromModule(self, module):          """Return a suite of all tests cases contained in the given module""" @@ -474,23 +494,20 @@ class TestLoader:          The method optionally resolves the names relative to a given module.          """ -        parts = string.split(name, '.') +        parts = name.split('.')          if module is None: -            if not parts: -                raise ValueError, "incomplete test name: %s" % name -            else: -                parts_copy = parts[:] -                while parts_copy: -                    try: -                        module = __import__(string.join(parts_copy,'.')) -                        break -                    except ImportError: -                        del parts_copy[-1] -                        if not parts_copy: raise +            parts_copy = parts[:] +            while parts_copy: +                try: +                    module = __import__('.'.join(parts_copy)) +                    break +                except ImportError: +                    del parts_copy[-1] +                    if not parts_copy: raise                  parts = parts[1:]          obj = module          for part in parts: -            obj = getattr(obj, part) +            parent, obj = obj, getattr(obj, part)          import unittest          if type(obj) == types.ModuleType: @@ -499,11 +516,13 @@ class TestLoader:                issubclass(obj, unittest.TestCase)):              return self.loadTestsFromTestCase(obj)          elif type(obj) == types.UnboundMethodType: +            return parent(obj.__name__)              return obj.im_class(obj.__name__) +        elif isinstance(obj, unittest.TestSuite): +            return obj          elif callable(obj):              test = obj() -            if not isinstance(test, unittest.TestCase) and \ -               not isinstance(test, unittest.TestSuite): +            if not isinstance(test, (unittest.TestCase, unittest.TestSuite)):                  raise ValueError, \                        "calling %s returned %s, not a test" % (obj,test)              return test @@ -514,16 +533,15 @@ class TestLoader:          """Return a suite of all tests cases found using the given sequence          of string specifiers. See 'loadTestsFromName()'.          """ -        suites = [] -        for name in names: -            suites.append(self.loadTestsFromName(name, module)) +        suites = [self.loadTestsFromName(name, module) for name in names]          return self.suiteClass(suites)      def getTestCaseNames(self, testCaseClass):          """Return a sorted sequence of method names found within testCaseClass          """ -        testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p, -                             dir(testCaseClass)) +        def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): +            return attrname[:len(prefix)] == prefix and callable(getattr(testCaseClass, attrname)) +        testFnNames = filter(isTestMethod, dir(testCaseClass))          for baseclass in testCaseClass.__bases__:              for testFnName in self.getTestCaseNames(baseclass):                  if testFnName not in testFnNames:  # handle overridden methods @@ -706,7 +724,7 @@ Examples:                   argv=None, testRunner=None, testLoader=defaultTestLoader):          if type(module) == type(''):              self.module = __import__(module) -            for part in string.split(module,'.')[1:]: +            for part in module.split('.')[1:]:                  self.module = getattr(self.module, part)          else:              self.module = module | 
