diff options
Diffstat (limited to 'Lib/unittest/main.py')
| -rw-r--r-- | Lib/unittest/main.py | 25 | 
1 files changed, 20 insertions, 5 deletions
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index 807604f08d..e62469aa2a 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -46,6 +46,12 @@ def _convert_names(names):      return [_convert_name(name) for name in names] +def _convert_select_pattern(pattern): +    if not '*' in pattern: +        pattern = '*%s*' % pattern +    return pattern + +  class TestProgram(object):      """A command-line program that runs a set of tests; this is primarily         for making test modules conveniently executable. @@ -53,7 +59,7 @@ class TestProgram(object):      # defaults for testing      module=None      verbosity = 1 -    failfast = catchbreak = buffer = progName = warnings = None +    failfast = catchbreak = buffer = progName = warnings = testNamePatterns = None      _discovery_parser = None      def __init__(self, module='__main__', defaultTest=None, argv=None, @@ -140,8 +146,13 @@ class TestProgram(object):              self.testNames = list(self.defaultTest)          self.createTests() -    def createTests(self): -        if self.testNames is None: +    def createTests(self, from_discovery=False, Loader=None): +        if self.testNamePatterns: +            self.testLoader.testNamePatterns = self.testNamePatterns +        if from_discovery: +            loader = self.testLoader if Loader is None else Loader() +            self.test = loader.discover(self.start, self.pattern, self.top) +        elif self.testNames is None:              self.test = self.testLoader.loadTestsFromModule(self.module)          else:              self.test = self.testLoader.loadTestsFromNames(self.testNames, @@ -179,6 +190,11 @@ class TestProgram(object):                                  action='store_true',                                  help='Buffer stdout and stderr during tests')              self.buffer = False +        if self.testNamePatterns is None: +            parser.add_argument('-k', dest='testNamePatterns', +                                action='append', type=_convert_select_pattern, +                                help='Only run tests which match the given substring') +            self.testNamePatterns = []          return parser @@ -225,8 +241,7 @@ class TestProgram(object):                  self._initArgParsers()              self._discovery_parser.parse_args(argv, self) -        loader = self.testLoader if Loader is None else Loader() -        self.test = loader.discover(self.start, self.pattern, self.top) +        self.createTests(from_discovery=True, Loader=Loader)      def runTests(self):          if self.catchbreak:  | 
