summaryrefslogtreecommitdiff
path: root/Lib/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ast.py')
-rw-r--r--Lib/ast.py100
1 files changed, 92 insertions, 8 deletions
diff --git a/Lib/ast.py b/Lib/ast.py
index 03b8a1b16b..6c1e978b05 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -115,10 +115,10 @@ def dump(node, annotate_fields=True, include_attributes=False):
def copy_location(new_node, old_node):
"""
- Copy source location (`lineno` and `col_offset` attributes) from
- *old_node* to *new_node* if possible, and return *new_node*.
+ Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset`
+ attributes) from *old_node* to *new_node* if possible, and return *new_node*.
"""
- for attr in 'lineno', 'col_offset':
+ for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
if attr in old_node._attributes and attr in new_node._attributes \
and hasattr(old_node, attr):
setattr(new_node, attr, getattr(old_node, attr))
@@ -133,31 +133,44 @@ def fix_missing_locations(node):
recursively where not already set, by setting them to the values of the
parent node. It works recursively starting at *node*.
"""
- def _fix(node, lineno, col_offset):
+ def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
if 'lineno' in node._attributes:
if not hasattr(node, 'lineno'):
node.lineno = lineno
else:
lineno = node.lineno
+ if 'end_lineno' in node._attributes:
+ if not hasattr(node, 'end_lineno'):
+ node.end_lineno = end_lineno
+ else:
+ end_lineno = node.end_lineno
if 'col_offset' in node._attributes:
if not hasattr(node, 'col_offset'):
node.col_offset = col_offset
else:
col_offset = node.col_offset
+ if 'end_col_offset' in node._attributes:
+ if not hasattr(node, 'end_col_offset'):
+ node.end_col_offset = end_col_offset
+ else:
+ end_col_offset = node.end_col_offset
for child in iter_child_nodes(node):
- _fix(child, lineno, col_offset)
- _fix(node, 1, 0)
+ _fix(child, lineno, col_offset, end_lineno, end_col_offset)
+ _fix(node, 1, 0, 1, 0)
return node
def increment_lineno(node, n=1):
"""
- Increment the line number of each node in the tree starting at *node* by *n*.
- This is useful to "move code" to a different location in a file.
+ Increment the line number and end line number of each node in the tree
+ starting at *node* by *n*. This is useful to "move code" to a different
+ location in a file.
"""
for child in walk(node):
if 'lineno' in child._attributes:
child.lineno = getattr(child, 'lineno', 0) + n
+ if 'end_lineno' in child._attributes:
+ child.end_lineno = getattr(child, 'end_lineno', 0) + n
return node
@@ -213,6 +226,77 @@ def get_docstring(node, clean=True):
return text
+def _splitlines_no_ff(source):
+ """Split a string into lines ignoring form feed and other chars.
+
+ This mimics how the Python parser splits source code.
+ """
+ idx = 0
+ lines = []
+ next_line = ''
+ while idx < len(source):
+ c = source[idx]
+ next_line += c
+ idx += 1
+ # Keep \r\n together
+ if c == '\r' and idx < len(source) and source[idx] == '\n':
+ next_line += '\n'
+ idx += 1
+ if c in '\r\n':
+ lines.append(next_line)
+ next_line = ''
+
+ if next_line:
+ lines.append(next_line)
+ return lines
+
+
+def _pad_whitespace(source):
+ """Replace all chars except '\f\t' in a line with spaces."""
+ result = ''
+ for c in source:
+ if c in '\f\t':
+ result += c
+ else:
+ result += ' '
+ return result
+
+
+def get_source_segment(source, node, *, padded=False):
+ """Get source code segment of the *source* that generated *node*.
+
+ If some location information (`lineno`, `end_lineno`, `col_offset`,
+ or `end_col_offset`) is missing, return None.
+
+ If *padded* is `True`, the first line of a multi-line statement will
+ be padded with spaces to match its original position.
+ """
+ try:
+ lineno = node.lineno - 1
+ end_lineno = node.end_lineno - 1
+ col_offset = node.col_offset
+ end_col_offset = node.end_col_offset
+ except AttributeError:
+ return None
+
+ lines = _splitlines_no_ff(source)
+ if end_lineno == lineno:
+ return lines[lineno].encode()[col_offset:end_col_offset].decode()
+
+ if padded:
+ padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
+ else:
+ padding = ''
+
+ first = padding + lines[lineno].encode()[col_offset:].decode()
+ last = lines[end_lineno].encode()[:end_col_offset].decode()
+ lines = lines[lineno+1:end_lineno]
+
+ lines.insert(0, first)
+ lines.append(last)
+ return ''.join(lines)
+
+
def walk(node):
"""
Recursively yield all descendant nodes in the tree starting at *node*