summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Kluyver <takowl@gmail.com>2013-09-30 16:27:13 -0700
committerThomas Kluyver <takowl@gmail.com>2013-09-30 16:27:13 -0700
commit3e67d99d201c929fd89bc472fe716f6116b2f5d9 (patch)
tree426d303730c775d3127385a445b31ddf03959705
parentae26a87df58d447176f8054c52d8113ef1831023 (diff)
downloadpexpect-3e67d99d201c929fd89bc472fe716f6116b2f5d9.tar.gz
Improve test for bad arguments to expect() and expect_exact()
-rw-r--r--pexpect/__init__.py29
-rwxr-xr-xtests/test_expect.py36
2 files changed, 52 insertions, 13 deletions
diff --git a/pexpect/__init__.py b/pexpect/__init__.py
index 7b69d3c..28824ff 100644
--- a/pexpect/__init__.py
+++ b/pexpect/__init__.py
@@ -1265,6 +1265,16 @@ class spawn(object):
if self.isalive():
os.kill(self.pid, sig)
+ def _pattern_type_err(self, pattern):
+ raise TypeError('got {badtype} ({badobj!r}) as pattern, must be one'
+ ' of: {goodtypes}, pexpect.EOF, pexpect.TIMEOUT'\
+ .format(badtype=type(pattern),
+ badobj=pattern,
+ goodtypes=', '.join([str(ast)\
+ for ast in self.allowed_string_types])
+ )
+ )
+
def compile_pattern_list(self, patterns):
'''This compiles a pattern-string or a list of pattern-strings.
@@ -1311,12 +1321,7 @@ class spawn(object):
elif isinstance(p, type(re.compile(''))):
compiled_pattern_list.append(p)
else:
- raise TypeError('pattern is %s at position %d, '
- 'must be one of: %s' % (
- str(type(p)), idx, ', '.join([str(ast)
- for ast in self.allowed_string_types
- ] + [str(EOF), str(TIMEOUT),
- str(type(re.compile('')))]),))
+ self._pattern_type_err(p)
return compiled_pattern_list
def expect(self, pattern, timeout=-1, searchwindowsize=-1):
@@ -1433,12 +1438,18 @@ class spawn(object):
pattern_list in (TIMEOUT, EOF)):
pattern_list = [pattern_list]
- def prepare_string(pattern):
+ def prepare_pattern(pattern):
if pattern in (TIMEOUT, EOF):
return pattern
- return self._coerce_expect_string(pattern)
+ if isinstance(pattern, self.allowed_string_types):
+ return self._coerce_expect_string(pattern)
+ self._pattern_type_err(pattern)
- pattern_list = [prepare_string(p) for p in pattern_list]
+ try:
+ pattern_list = iter(pattern_list)
+ except TypeError:
+ self._pattern_type_err(pattern_list)
+ pattern_list = [prepare_pattern(p) for p in pattern_list]
return self.expect_loop(searcher_string(pattern_list),
timeout, searchwindowsize)
diff --git a/tests/test_expect.py b/tests/test_expect.py
index 6aba431..48fc022 100755
--- a/tests/test_expect.py
+++ b/tests/test_expect.py
@@ -52,6 +52,29 @@ def hex_diff(left, right):
return '\n' + '\n'.join(diff,)
+class assert_raises_msg(object):
+ def __init__(self, errtype, msgpart):
+ self.errtype = errtype
+ self.msgpart = msgpart
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, etype, value, traceback):
+ if value is None:
+ raise AssertionError('Expected %s, but no exception was raised' \
+ % self.errtype)
+ if not isinstance(value, self.errtype):
+ raise AssertionError('Expected %s, but %s was raised' \
+ % (self.errtype, etype))
+
+ errstr = str(value)
+ if self.msgpart not in errstr:
+ raise AssertionError('%r was not in %r' % (self.msgpart, errstr))
+
+ return True
+
+
class ExpectTestCase (PexpectTestCase.PexpectTestCase):
def test_expect_basic (self):
@@ -451,10 +474,15 @@ class ExpectTestCase (PexpectTestCase.PexpectTestCase):
def test_bad_arg(self):
p = pexpect.spawn('cat')
- self.assertRaises(TypeError, p.expect, 1)
- self.assertRaises(TypeError, p.expect, [1, b'2'])
- self.assertRaises(TypeError, p.expect_exact, 1)
- self.assertRaises(TypeError, p.expect_exact, [1, b'2'])
+ with assert_raises_msg(TypeError, 'must be one of'):
+ p.expect(1)
+ with assert_raises_msg(TypeError, 'must be one of'):
+ p.expect([1, b'2'])
+
+ with assert_raises_msg(TypeError, 'must be one of'):
+ p.expect_exact(1)
+ with assert_raises_msg(TypeError, 'must be one of'):
+ p.expect_exact([1, b'2'])
if __name__ == '__main__':
unittest.main()