diff options
Diffstat (limited to 'Lib/ast.py')
-rw-r--r-- | Lib/ast.py | 100 |
1 files changed, 92 insertions, 8 deletions
diff --git a/Lib/ast.py b/Lib/ast.py index 03b8a1b16b..6c1e978b05 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -115,10 +115,10 @@ def dump(node, annotate_fields=True, include_attributes=False): def copy_location(new_node, old_node): """ - Copy source location (`lineno` and `col_offset` attributes) from - *old_node* to *new_node* if possible, and return *new_node*. + Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset` + attributes) from *old_node* to *new_node* if possible, and return *new_node*. """ - for attr in 'lineno', 'col_offset': + for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset': if attr in old_node._attributes and attr in new_node._attributes \ and hasattr(old_node, attr): setattr(new_node, attr, getattr(old_node, attr)) @@ -133,31 +133,44 @@ def fix_missing_locations(node): recursively where not already set, by setting them to the values of the parent node. It works recursively starting at *node*. """ - def _fix(node, lineno, col_offset): + def _fix(node, lineno, col_offset, end_lineno, end_col_offset): if 'lineno' in node._attributes: if not hasattr(node, 'lineno'): node.lineno = lineno else: lineno = node.lineno + if 'end_lineno' in node._attributes: + if not hasattr(node, 'end_lineno'): + node.end_lineno = end_lineno + else: + end_lineno = node.end_lineno if 'col_offset' in node._attributes: if not hasattr(node, 'col_offset'): node.col_offset = col_offset else: col_offset = node.col_offset + if 'end_col_offset' in node._attributes: + if not hasattr(node, 'end_col_offset'): + node.end_col_offset = end_col_offset + else: + end_col_offset = node.end_col_offset for child in iter_child_nodes(node): - _fix(child, lineno, col_offset) - _fix(node, 1, 0) + _fix(child, lineno, col_offset, end_lineno, end_col_offset) + _fix(node, 1, 0, 1, 0) return node def increment_lineno(node, n=1): """ - Increment the line number of each node in the tree starting at *node* by *n*. - This is useful to "move code" to a different location in a file. + Increment the line number and end line number of each node in the tree + starting at *node* by *n*. This is useful to "move code" to a different + location in a file. """ for child in walk(node): if 'lineno' in child._attributes: child.lineno = getattr(child, 'lineno', 0) + n + if 'end_lineno' in child._attributes: + child.end_lineno = getattr(child, 'end_lineno', 0) + n return node @@ -213,6 +226,77 @@ def get_docstring(node, clean=True): return text +def _splitlines_no_ff(source): + """Split a string into lines ignoring form feed and other chars. + + This mimics how the Python parser splits source code. + """ + idx = 0 + lines = [] + next_line = '' + while idx < len(source): + c = source[idx] + next_line += c + idx += 1 + # Keep \r\n together + if c == '\r' and idx < len(source) and source[idx] == '\n': + next_line += '\n' + idx += 1 + if c in '\r\n': + lines.append(next_line) + next_line = '' + + if next_line: + lines.append(next_line) + return lines + + +def _pad_whitespace(source): + """Replace all chars except '\f\t' in a line with spaces.""" + result = '' + for c in source: + if c in '\f\t': + result += c + else: + result += ' ' + return result + + +def get_source_segment(source, node, *, padded=False): + """Get source code segment of the *source* that generated *node*. + + If some location information (`lineno`, `end_lineno`, `col_offset`, + or `end_col_offset`) is missing, return None. + + If *padded* is `True`, the first line of a multi-line statement will + be padded with spaces to match its original position. + """ + try: + lineno = node.lineno - 1 + end_lineno = node.end_lineno - 1 + col_offset = node.col_offset + end_col_offset = node.end_col_offset + except AttributeError: + return None + + lines = _splitlines_no_ff(source) + if end_lineno == lineno: + return lines[lineno].encode()[col_offset:end_col_offset].decode() + + if padded: + padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode()) + else: + padding = '' + + first = padding + lines[lineno].encode()[col_offset:].decode() + last = lines[end_lineno].encode()[:end_col_offset].decode() + lines = lines[lineno+1:end_lineno] + + lines.insert(0, first) + lines.append(last) + return ''.join(lines) + + def walk(node): """ Recursively yield all descendant nodes in the tree starting at *node* |