summaryrefslogtreecommitdiff
path: root/astroid
diff options
context:
space:
mode:
authorCeridwen <ceridwenv@gmail.com>2016-03-08 17:02:29 -0500
committerCeridwen <ceridwenv@gmail.com>2016-03-08 17:02:29 -0500
commit77a111204e15635130567b1d0eb8d8f147996fcc (patch)
tree80e383345bb34fc865eb2d80aeec2fcd57858117 /astroid
parent380fabb904b1f5e2eaaef4f239a6cc38a65f7b12 (diff)
downloadastroid-git-77a111204e15635130567b1d0eb8d8f147996fcc.tar.gz
Add comments and docstrings to the zipper and the tests
Diffstat (limited to 'astroid')
-rw-r--r--astroid/tests/unittest_zipper.py76
-rw-r--r--astroid/tree/zipper.py129
2 files changed, 183 insertions, 22 deletions
diff --git a/astroid/tests/unittest_zipper.py b/astroid/tests/unittest_zipper.py
index f32505c9..b8d5bb58 100644
--- a/astroid/tests/unittest_zipper.py
+++ b/astroid/tests/unittest_zipper.py
@@ -1,6 +1,12 @@
+'''Rather than generating a random AST, the zipper tests pick a random
+file out of astroid's code and parse it, running tests on the
+resulting AST. The tests create a dict-of-lists graph representation
+of the AST by using the recursive structure only, without using the
+zipper, with each node labeled by a unique integer, and then compare
+the zipper's result with what the zipper should return.
+
+'''
import collections
-import itertools
-import pprint
import os
import unittest
@@ -13,19 +19,33 @@ from astroid.tree import base
from astroid.tree import zipper
-def _all_subclasses(cls):
- return cls.__subclasses__() + [g for s in cls.__subclasses__()
- for g in _all_subclasses(s)]
-node_types_strategy = strategies.sampled_from(_all_subclasses(base.NodeNG))
-
-# This screens out the empty init files.
+# This is a strategy that generates a random file name, screening out
+# the empty init files because they produce 1-element ASTs that aren't
+# useful for testing.
astroid_file = strategies.sampled_from(os.path.join(p, n) for p, _, ns in os.walk('astroid/') for n in ns if n.endswith('.py') and '__init__.py' not in n)
+
class ASTMap(dict):
+ '''Hypothesis uses the repr of arguments to a function when printing
+ output for failed tests but the ASTs are too large to be legible,
+ so this is a simple dict subclass with a shortened repr.
+
+ '''
def __repr__(self):
return '{ 1: ' + repr(self[1]) + '...}'
class AssignLabels(object):
+ '''Traverses an AST, creating a dict with integer labels representing
+ AST nodes.
+
+ The keys of the resulting dictionary contain the actual AST node,
+ the labels of its children, and the label of its parent. The
+ labels are assigned starting at 1, for the root, in prefix order.
+ This is a replacement for an inner function in ast_from_file_name.
+ On Python 3, self.label would instead be a closure variable with
+ the nonlocal statement.
+
+ '''
Node = collections.namedtuple('Node', 'node children parent')
def __init__(self):
self.label = 1
@@ -36,12 +56,29 @@ class AssignLabels(object):
labels[label] = self.Node(node, children, parent_label)
return label
+
Node = collections.namedtuple('Node', 'node children parent edges')
+# Each edge represents a valid zipper method, with move being the
+# function corresponding to that method and label corresponding to the
+# label of the destination node.
Edge = collections.namedtuple('Edge', 'label move')
AST_CACHE = {}
def ast_from_file_name(name):
+ '''Takes a file name and creates a dict-of-lists representation of that AST.
+
+ Each key is a unique integer assigned to a node, the values are
+ tuples containing the actual node, the integer labels of the
+ children, the label of the parent, and pairs of zipper
+ methods/functions with the labels of the corresponding node that
+ zipper function will generate when applied at the key's node's
+ position.
+
+ '''
+ # Generating ASTs is slow right now because it depends on
+ # inference, so this caches one AST per file. Avoiding the global
+ # caching ensures that other tests can't mutate these ASTs.
if name in AST_CACHE:
return AST_CACHE[name]
with open(name, 'r') as source_file:
@@ -73,10 +110,10 @@ def ast_from_file_name(name):
AST_CACHE[name] = ast
return ast
+# Build a strategy that generates digraph representations of ASTs from
+# file names.
ast_strategy = strategies.builds(ast_from_file_name, astroid_file)
-# pprint.pprint(ast_strategy.example())
-
def check_linked_list(linked_list):
'''Check that this linked list of tuples is correctly formed.'''
while linked_list:
@@ -86,6 +123,7 @@ def check_linked_list(linked_list):
assert(len(linked_list) == 0)
def check_zipper(position):
+ '''Check that a zipper is correctly formed.'''
assert(isinstance(position, (base.NodeNG, collections.Sequence)))
assert(isinstance(position._self_path, (zipper.Path, type(None))))
if position._self_path:
@@ -95,6 +133,11 @@ def check_zipper(position):
check_linked_list(position._self_path.parent_nodes)
assert isinstance(position._self_path.changed, bool)
+# These two functions are recursive implementations of preorder and
+# postorder traversals that iterate over labels rather than nodes.
+# Using recursion reduces the probability of the error being in both
+# implementations, the recursive test and the iterative functional
+# code.
def preorder_descendants(label, ast, dont_recurse_on=None):
def _preorder_descendants(label):
if dont_recurse_on is not None and isinstance(ast[label].node, dont_recurse_on):
@@ -111,15 +154,16 @@ def postorder_descendants(label, ast, dont_recurse_on=None):
return sum((_postorder_descendants(l) for l in ast[label].children), ()) + (label,)
return sum((_postorder_descendants(l) for l in ast[label].children), ()) + (label,)
+# This test function uses a set-based implementation for finding the
+# common parent rather than the reverse-based implementation in the
+# functional code.
def common_ancestor(label1, label2, ast):
ancestors = set()
while label1:
if ast[label1].node is not nodes.Empty:
ancestors.add(label1)
label1 = ast[label1].parent
- # print([ast[a].node for a in ancestors])
while label2 not in ancestors:
- # print(repr(ast[label2].node))
label2 = ast[label2].parent
return label2
@@ -135,6 +179,14 @@ def traverse_to_node(label, ast, location):
location = move(location)
return location
+# This function and strategy creates a strategy for generating a
+# random node class, for testing that the iterators properly exclude
+# nodes of that type and their descendants.
+def _all_subclasses(cls):
+ return cls.__subclasses__() + [g for s in cls.__subclasses__()
+ for g in _all_subclasses(s)]
+node_types_strategy = strategies.sampled_from(_all_subclasses(base.NodeNG))
+
class TestZipper(unittest.TestCase):
@hypothesis.settings(perform_health_check=False)
diff --git a/astroid/tree/zipper.py b/astroid/tree/zipper.py
index 220514a9..cd30b6e9 100644
--- a/astroid/tree/zipper.py
+++ b/astroid/tree/zipper.py
@@ -1,3 +1,14 @@
+'''This contains an implementation of a zipper for astroid ASTs.
+
+A zipper is a data structure for traversing and editing immutable
+recursive data types that can act as a doubly-linked structure without
+actual double links.
+http://blog.ezyang.com/2010/04/you-could-have-invented-zippers/ has a
+brief introduction to zippers as a whole. This implementation is
+based on the Clojure implementation,
+https://github.com/clojure/clojure/blob/master/src/clj/clojure/zip.clj .
+
+'''
import collections
# Because every zipper method creates a new zipper, zipper creation
@@ -80,20 +91,56 @@ def initial(linked_list):
return tail
+# Attributes:
+# left (linked list): The siblings to the left of the zipper's focus.
+# right (linked list): The siblings to the right of the zipper's focus.
+# parent_nodes (linked list): The ancestors of the zipper's focus
+# parent_path (Path): The Path from the zipper that created this zipper.
+# changed (bool): Whether this zipper has been edited or not.
Path = collections.namedtuple('Path', 'left right parent_nodes parent_path changed')
class Zipper(wrapt.ObjectProxy):
+    '''This is an object-oriented version of a zipper with methods instead of
+ functions. All the methods return a new zipper or None, and none
+ of them mutate the underlying AST nodes. They return None when
+ the method is not valid for that zipper. The zipper acts as a
+ proxy so the underlying node's or sequence's methods and
+ attributes are accessible through it.
+
+ Attributes:
+ __wrapped__ (base.NodeNG, collections.Sequence): The AST node or
+ sequence at the zipper's focus.
+ _self_path (Path): The Path tuple containing information about the
+ zipper's history. This must be accessed as ._self_path.
+
+ '''
__slots__ = ('path')
- # Setting wrapt.ObjectProxy.__init__ as a default value turns it into a
- # local variable, avoiding a super() call, two globals lookups,
- # and two dict lookups (on wrapt's and ObjectProxy's dicts).
- def __init__(self, focus, path=None, init=wrapt.ObjectProxy.__init__):
- init(self, focus)
+ # Setting wrapt.ObjectProxy.__init__ as a default value turns it
+ # into a local variable, avoiding a super() call, two globals
+ # lookups, and two dict lookups (on wrapt's and ObjectProxy's
+ # dicts) in the most common zipper operation on CPython.
+ def __init__(self, focus, path=None, _init=wrapt.ObjectProxy.__init__):
+ '''Make a new zipper.
+
+ Args:
+ focus (base.NodeNG, collections.Sequence): The focus for this
+ zipper, will be assigned to self.__wrapped__ by
+ wrapt.ObjectProxy's __init__.
+ path: The path of the zipper used to create the new zipper, if any.
+
+ Returns:
+ A new zipper object.
+ '''
+ _init(self, focus)
self._self_path = path
+
# Traversal
def left(self):
+ '''Go to the next sibling that's directly to the left of the focus.
+
+ This takes constant time.'''
if self._self_path and self._self_path.left:
focus, left = self._self_path.left
path = self._self_path._replace(left=left,
@@ -102,12 +149,18 @@ class Zipper(wrapt.ObjectProxy):
return type(self)(focus=focus, path=path)
def leftmost(self):
+ '''Go to the leftmost sibling of the focus.
+
+ This takes time linear in the number of left siblings.'''
if self._self_path and self._self_path.left:
focus, siblings = last(self._self_path.left), initial(self._self_path.left)
path = self._self_path._replace(left=(), right=concatenate(reverse(siblings), (self.__wrapped__, self._self_path.right)))
return type(self)(focus=focus, path=path)
def right(self):
+ '''Go to the next sibling that's directly to the right of the focus.
+
+ This takes constant time.'''
if self._self_path and self._self_path.right:
focus, right = self._self_path.right
path = self._self_path._replace(left=(self.__wrapped__,
@@ -116,6 +169,9 @@ class Zipper(wrapt.ObjectProxy):
return type(self)(focus=focus, path=path)
def rightmost(self):
+ '''Go to the rightmost sibling of the focus.
+
+ This takes time linear in the number of right siblings.'''
if self._self_path and self._self_path.right:
siblings, focus = initial(self._self_path.right), last(self._self_path.right)
path = self._self_path._replace(left=concatenate(reverse(siblings), (self.__wrapped__, self._self_path.left)),
@@ -123,6 +179,9 @@ class Zipper(wrapt.ObjectProxy):
return type(self)(focus=focus, path=path)
def down(self):
+ '''Go to the leftmost child of the focus.
+
+ This takes constant time.'''
try:
children = iter(self.__wrapped__)
first = next(children)
@@ -137,10 +196,19 @@ class Zipper(wrapt.ObjectProxy):
return type(self)(focus=first, path=path)
def up(self):
+ '''Go to the parent of the focus.
+
+ This takes time linear in the number of left siblings if the
+ focus has been edited or constant time if it hasn't been
+ edited.
+
+ '''
if self._self_path:
left, right, parent_nodes, parent_path, changed = self._self_path
if parent_nodes:
focus = parent_nodes[0]
+ # This conditional uses parent_nodes to make going up
+ # take constant time if the focus hasn't been edited.
if changed:
return type(self)(
focus=focus.make_node(concatenate(reverse(left), (self.__wrapped__, right))),
@@ -149,13 +217,24 @@ class Zipper(wrapt.ObjectProxy):
return type(self)(focus=focus, path=parent_path)
def root(self):
- """return the root node of the tree"""
+ '''Go to the root of the AST for the focus.
+
+ This takes time linear in the number of ancestors of the focus.'''
location = self
while location._self_path:
location = location.up()
return location
def common_ancestor(self, other):
+ '''Find the most recent common ancestor of two different zippers.
+
+ This takes time linear in the number of ancestors of both foci
+ and will return None for zippers from two different ASTs. The
+ new zipper is derived from the zipper the method is called on,
+ so edits in the second argument will not be included in the
+ new zipper.
+
+ '''
if self._self_path:
self_ancestors = reverse((self.__wrapped__, self._self_path.parent_nodes))
else:
@@ -182,12 +261,24 @@ class Zipper(wrapt.ObjectProxy):
return location
def get_children(self):
+ '''Iterates over the children of the focus.'''
child = self.down()
while child is not None:
yield child
child = child.right()
+ # Iterative algorithms for these two methods, with explicit
+ # stacks, avoid the problem of yield from only being available on
+ # Python 3 and ensure that no AST will overflow the call stack.
+ # On CPython, avoiding the extra function calls necessary for a
+ # recursive algorithm will probably make them faster too.
def preorder_descendants(self, dont_recurse_on=None):
+ '''Iterates over the descendants of the focus in prefix order.
+
+ Args:
+ dont_recurse_on (base.NodeNG): If not None, will not include nodes
+ of this type or types or any of the descendants of those nodes.
+ '''
to_visit = [self]
while to_visit:
location = to_visit.pop()
@@ -201,6 +292,12 @@ class Zipper(wrapt.ObjectProxy):
if not isinstance(c, dont_recurse_on))
def postorder_descendants(self, dont_recurse_on=None):
+ '''Iterates over the descendants of the focus in postfix order.
+
+ Args:
+ dont_recurse_on (base.NodeNG): If not None, will not include nodes
+ of this type or types or any of the descendants of those nodes.
+ '''
to_visit = [self]
visited_ancestors = []
while to_visit:
@@ -220,10 +317,14 @@ class Zipper(wrapt.ObjectProxy):
to_visit.pop()
def find_descendants_of_type(self, cls, skip_class=None):
- """return an iterator on nodes which are instance of the given class(es)
-
- cls may be a class object or a tuple of class objects
- """
+ '''Iterates over the descendants of the focus of a given type in
+ prefix order.
+
+ Args:
+ skip_class (base.NodeNG, tuple(base.NodeNG)): If not None, will
+ not include nodes of this type or types or any of the
+ descendants of those nodes.
+ '''
return (d for d in self.preorder_descendants(skip_class) if isinstance(node, cls))
# if isinstance(self, cls):
# yield self
@@ -241,6 +342,8 @@ class Zipper(wrapt.ObjectProxy):
# Legacy APIs
@property
def parent(self):
+ '''Goes up to the next ancestor of the focus that's a node, not a
+ sequence.'''
location = self.up()
if isinstance(location, collections.Sequence):
return location.up()
@@ -292,6 +395,12 @@ class Zipper(wrapt.ObjectProxy):
# Editing
def replace(self, focus):
+ '''Replaces the existing node at the focus.
+
+ Args:
+ focus (base.NodeNG, collections.Sequence): The object to replace
+ the focus with.
+ '''
return type(self)(focus=focus, path=self._self_path._replace(changed=True))
# def edit(self, *args, **kws):