diff options
author | Tyson Andre <tyson.andre@uwaterloo.ca> | 2019-08-21 13:15:22 -0400 |
---|---|---|
committer | Eli Bendersky <eliben@users.noreply.github.com> | 2019-08-21 10:15:22 -0700 |
commit | a350f0d11d17cac15a1fc7ecae2fae9006872f8b (patch) | |
tree | 9f6a06a27d13f6fa960da28aa29bf9acfa32b79c | |
parent | bc2010aea92535cb1d70be9fc1bebeb6eff229d8 (diff) | |
download | pycparser-a350f0d11d17cac15a1fc7ecae2fae9006872f8b.tar.gz |
Fix error transforming an empty switch (#346)
* Fix error transforming an empty switch
The parser would crash on that line for `switch(1) {}`
because NoneType is not iterable.
Fixes #345
* Add a test of empty switch statements
* Address review comments
-rw-r--r-- | pycparser/ast_transforms.py | 3 | ||||
-rwxr-xr-x | tests/test_c_parser.py | 14 |
2 files changed, 16 insertions, 1 deletions
diff --git a/pycparser/ast_transforms.py b/pycparser/ast_transforms.py index ba50966..0aeb88f 100644 --- a/pycparser/ast_transforms.py +++ b/pycparser/ast_transforms.py @@ -74,7 +74,8 @@ def fix_switch_cases(switch_node): # Goes over the children of the Compound below the Switch, adding them # either directly below new_compound or below the last Case as appropriate - for child in switch_node.stmt.block_items: + # (for `switch(cond) {}`, block_items would have been None) + for child in (switch_node.stmt.block_items or []): if isinstance(child, (c_ast.Case, c_ast.Default)): # If it's a Case/Default: # 1. Add it to the Compound and mark as "last case" diff --git a/tests/test_c_parser.py b/tests/test_c_parser.py index ad9a218..49cada3 100755 --- a/tests/test_c_parser.py +++ b/tests/test_c_parser.py @@ -1792,6 +1792,7 @@ class TestCParser_whole_code(TestCParser_base): switch = ps1.ext[0].body.block_items[0] block = switch.stmt.block_items + self.assertEqual(len(block), 4) assert_case_node(block[0], '10') self.assertEqual(len(block[0].stmts), 3) assert_case_node(block[1], '20') @@ -1819,6 +1820,7 @@ class TestCParser_whole_code(TestCParser_base): switch = ps2.ext[0].body.block_items[0] block = switch.stmt.block_items + self.assertEqual(len(block), 5) assert_default_node(block[0]) self.assertEqual(len(block[0].stmts), 2) assert_case_node(block[1], '10') @@ -1830,6 +1832,18 @@ class TestCParser_whole_code(TestCParser_base): assert_case_node(block[4], '40') self.assertEqual(len(block[4].stmts), 1) + s3 = r''' + int foo(void) { + switch (myvar) { + } + return 0; + } + ''' + ps3 = self.parse(s3) + switch = ps3.ext[0].body.block_items[0] + + self.assertEqual(switch.stmt.block_items, []) + def test_for_statement(self): s2 = r''' void x(void) |