diff options
author | Thomas Kluyver <takowl@gmail.com> | 2013-09-30 16:27:13 -0700 |
---|---|---|
committer | Thomas Kluyver <takowl@gmail.com> | 2013-09-30 16:27:13 -0700 |
commit | 3e67d99d201c929fd89bc472fe716f6116b2f5d9 (patch) | |
tree | 426d303730c775d3127385a445b31ddf03959705 | |
parent | ae26a87df58d447176f8054c52d8113ef1831023 (diff) | |
download | pexpect-3e67d99d201c929fd89bc472fe716f6116b2f5d9.tar.gz |
Improve test for bad arguments to expect() and expect_exact()
-rw-r--r-- | pexpect/__init__.py | 29 | ||||
-rwxr-xr-x | tests/test_expect.py | 36 |
2 files changed, 52 insertions, 13 deletions
diff --git a/pexpect/__init__.py b/pexpect/__init__.py index 7b69d3c..28824ff 100644 --- a/pexpect/__init__.py +++ b/pexpect/__init__.py @@ -1265,6 +1265,16 @@ class spawn(object): if self.isalive(): os.kill(self.pid, sig) + def _pattern_type_err(self, pattern): + raise TypeError('got {badtype} ({badobj!r}) as pattern, must be one' + ' of: {goodtypes}, pexpect.EOF, pexpect.TIMEOUT'\ + .format(badtype=type(pattern), + badobj=pattern, + goodtypes=', '.join([str(ast)\ + for ast in self.allowed_string_types]) + ) + ) + def compile_pattern_list(self, patterns): '''This compiles a pattern-string or a list of pattern-strings. @@ -1311,12 +1321,7 @@ class spawn(object): elif isinstance(p, type(re.compile(''))): compiled_pattern_list.append(p) else: - raise TypeError('pattern is %s at position %d, ' - 'must be one of: %s' % ( - str(type(p)), idx, ', '.join([str(ast) - for ast in self.allowed_string_types - ] + [str(EOF), str(TIMEOUT), - str(type(re.compile('')))]),)) + self._pattern_type_err(p) return compiled_pattern_list def expect(self, pattern, timeout=-1, searchwindowsize=-1): @@ -1433,12 +1438,18 @@ class spawn(object): pattern_list in (TIMEOUT, EOF)): pattern_list = [pattern_list] - def prepare_string(pattern): + def prepare_pattern(pattern): if pattern in (TIMEOUT, EOF): return pattern - return self._coerce_expect_string(pattern) + if isinstance(pattern, self.allowed_string_types): + return self._coerce_expect_string(pattern) + self._pattern_type_err(pattern) - pattern_list = [prepare_string(p) for p in pattern_list] + try: + pattern_list = iter(pattern_list) + except TypeError: + self._pattern_type_err(pattern_list) + pattern_list = [prepare_pattern(p) for p in pattern_list] return self.expect_loop(searcher_string(pattern_list), timeout, searchwindowsize) diff --git a/tests/test_expect.py b/tests/test_expect.py index 6aba431..48fc022 100755 --- a/tests/test_expect.py +++ b/tests/test_expect.py @@ -52,6 +52,29 @@ def hex_diff(left, right): return '\n' + '\n'.join(diff,) +class assert_raises_msg(object): + def __init__(self, errtype, msgpart): + self.errtype = errtype + self.msgpart = msgpart + + def __enter__(self): + pass + + def __exit__(self, etype, value, traceback): + if value is None: + raise AssertionError('Expected %s, but no exception was raised' \ + % self.errtype) + if not isinstance(value, self.errtype): + raise AssertionError('Expected %s, but %s was raised' \ + % (self.errtype, etype)) + + errstr = str(value) + if self.msgpart not in errstr: + raise AssertionError('%r was not in %r' % (self.msgpart, errstr)) + + return True + + class ExpectTestCase (PexpectTestCase.PexpectTestCase): def test_expect_basic (self): @@ -451,10 +474,15 @@ class ExpectTestCase (PexpectTestCase.PexpectTestCase): def test_bad_arg(self): p = pexpect.spawn('cat') - self.assertRaises(TypeError, p.expect, 1) - self.assertRaises(TypeError, p.expect, [1, b'2']) - self.assertRaises(TypeError, p.expect_exact, 1) - self.assertRaises(TypeError, p.expect_exact, [1, b'2']) + with assert_raises_msg(TypeError, 'must be one of'): + p.expect(1) + with assert_raises_msg(TypeError, 'must be one of'): + p.expect([1, b'2']) + + with assert_raises_msg(TypeError, 'must be one of'): + p.expect_exact(1) + with assert_raises_msg(TypeError, 'must be one of'): + p.expect_exact([1, b'2']) if __name__ == '__main__': unittest.main() |