diff options
Diffstat (limited to 'test/test_argparse.py')
-rw-r--r-- | test/test_argparse.py | 89 |
1 files changed, 62 insertions, 27 deletions
diff --git a/test/test_argparse.py b/test/test_argparse.py index 2e78c88..8d3e97c 100644 --- a/test/test_argparse.py +++ b/test/test_argparse.py @@ -109,31 +109,42 @@ class NS(object): class ArgumentParserError(Exception): - def __init__(self, message, error_code): - Exception.__init__(self, message) + def __init__(self, message, stdout=None, stderr=None, error_code=None): self.message = message + self.stdout = stdout + self.stderr = stderr self.error_code = error_code -def stderr_to_parser_error(func, *args, **kwargs): - # if this is being called recursively and stderr is already being +def stderr_to_parser_error(parse_args, *args, **kwargs): + # if this is being called recursively and stderr or stdout is already being # redirected, simply call the function and let the enclosing function # catch the exception - if isinstance(sys.stderr, StringIO): - return func(*args, **kwargs) + if isinstance(sys.stderr, StringIO) or isinstance(sys.stdout, StringIO): + return parse_args(*args, **kwargs) # if this is not being called recursively, redirect stderr and # use it as the ArgumentParserError message + old_stdout = sys.stdout old_stderr = sys.stderr + sys.stdout = StringIO() sys.stderr = StringIO() try: try: - return func(*args, **kwargs) + result = parse_args(*args, **kwargs) + for key in list(vars(result)): + if getattr(result, key) is sys.stdout: + setattr(result, key, old_stdout) + if getattr(result, key) is sys.stderr: + setattr(result, key, old_stderr) + return result except SystemExit: code = sys.exc_info()[1].code - message = sys.stderr.getvalue() - raise ArgumentParserError(message, code) + stdout = sys.stdout.getvalue() + stderr = sys.stderr.getvalue() + raise ArgumentParserError("SystemExit", stdout, stderr, code) finally: + sys.stdout = old_stdout sys.stderr = old_stderr @@ -1341,6 +1352,7 @@ class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase): file.close() class FromFileConverterArgumentParser(ErrorRaisingArgumentParser): + def convert_arg_line_to_args(self, arg_line): for arg in arg_line.split(): if not arg.strip(): @@ -1799,10 +1811,10 @@ class TestAddSubparsers(TestCase): self.parser.parse_args(args_str.split()) except ArgumentParserError: err = sys.exc_info()[1] - if err.message != expected_help: + if err.stdout != expected_help: print(repr(expected_help)) - print(repr(err.message)) - self.assertEqual(err.message, expected_help) + print(repr(err.stdout)) + self.assertEqual(err.stdout, expected_help) def test_subparser1_help(self): self._test_subparser_help('5.0 1 -h', textwrap.dedent('''\ @@ -2553,6 +2565,27 @@ class TestGetDefault(TestCase): self.assertEqual("badger", parser.get_default("foo")) self.assertEqual(42, parser.get_default("bar")) +# ========================== +# Namespace 'contains' tests +# ========================== + +class TestNamespaceContainsSimple(TestCase): + + def test_empty(self): + ns = argparse.Namespace() + self.assertEquals('' in ns, False) + self.assertEquals('' not in ns, True) + self.assertEquals('x' in ns, False) + + def test_non_empty(self): + ns = argparse.Namespace(x=1, y=2) + self.assertEquals('x' in ns, True) + self.assertEquals('x' not in ns, False) + self.assertEquals('y' in ns, True) + self.assertEquals('' in ns, False) + self.assertEquals('xx' in ns, False) + self.assertEquals('z' in ns, False) + # ===================== # Help formatting tests # ===================== @@ -2565,8 +2598,9 @@ class TestHelpFormattingMetaclass(type): class AddTests(object): - def __init__(self, test_class, func_suffix): + def __init__(self, test_class, func_suffix, std_name): self.func_suffix = func_suffix + self.std_name = std_name for test_func in [self.test_format, self.test_print, @@ -2617,13 +2651,13 @@ class TestHelpFormattingMetaclass(type): def test_print(self, tester): parser = self._get_parser(tester) print_ = getattr(parser, 'print_%s' % self.func_suffix) - oldstderr = sys.stderr - sys.stderr = StringIO() + old_stream = getattr(sys, self.std_name) + setattr(sys, self.std_name, StringIO()) try: print_() - parser_text = sys.stderr.getvalue() + parser_text = getattr(sys, self.std_name).getvalue() finally: - sys.stderr = oldstderr + setattr(sys, self.std_name, old_stream) self._test(tester, parser_text) def test_print_file(self, tester): @@ -2635,8 +2669,10 @@ class TestHelpFormattingMetaclass(type): self._test(tester, parser_text) # add tests for {format,print}_{usage,help,version} - for func_suffix in ['usage', 'help', 'version']: - AddTests(cls, func_suffix) + for func_suffix, std_name in [('usage', 'stdout'), + ('help', 'stdout'), + ('version', 'stderr')]: + AddTests(cls, func_suffix, std_name) bases = TestCase, HelpTestCase = TestHelpFormattingMetaclass('HelpTestCase', bases, {}) @@ -3877,24 +3913,23 @@ class TestConflictHandling(TestCase): class TestOptionalsHelpVersionActions(TestCase): """Test the help and version actions""" - def _get_error_message(self, func, *args, **kwargs): + def _get_error(self, func, *args, **kwargs): try: func(*args, **kwargs) except ArgumentParserError: - err = sys.exc_info()[1] - return err.message + return sys.exc_info()[1] else: self.assertRaises(ArgumentParserError, func, *args, **kwargs) def assertPrintHelpExit(self, parser, args_str): self.assertEqual( parser.format_help(), - self._get_error_message(parser.parse_args, args_str.split())) + self._get_error(parser.parse_args, args_str.split()).stdout) def assertPrintVersionExit(self, parser, args_str): self.assertEqual( parser.format_version(), - self._get_error_message(parser.parse_args, args_str.split())) + self._get_error(parser.parse_args, args_str.split()).stderr) def assertArgumentParserError(self, parser, *args): self.assertRaises(ArgumentParserError, parser.parse_args, args) @@ -3908,7 +3943,7 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_format(self): parser = ErrorRaisingArgumentParser(prog='PPP', version='%(prog)s 3.5') - msg = self._get_error_message(parser.parse_args, ['-v']) + msg = self._get_error(parser.parse_args, ['-v']).stderr self.assertEqual('PPP 3.5\n', msg) def test_version_no_help(self): @@ -3921,7 +3956,7 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_action(self): parser = ErrorRaisingArgumentParser(prog='XXX') parser.add_argument('-V', action='version', version='%(prog)s 3.7') - msg = self._get_error_message(parser.parse_args, ['-V']) + msg = self._get_error(parser.parse_args, ['-V']).stderr self.assertEqual('XXX 3.7\n', msg) def test_no_help(self): @@ -4089,7 +4124,7 @@ class TestArgumentError(TestCase): parser.parse_args(['XXX']) except ArgumentParserError: expected = 'usage: PROG x\nPROG: error: argument x: spam!\n' - msg = str(sys.exc_info()[1]) + msg = sys.exc_info()[1].stderr self.failUnlessEqual(expected, msg) else: self.fail() |