summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2014-10-20 17:41:05 +0300
committerClaudiu Popa <pcmanticore@gmail.com>2014-10-20 17:41:05 +0300
commit2e2d04b4c270f31c1a9883e9ae41bb2d4cbca3fe (patch)
tree0be6f15bf00f1a276418ca292cad029d53ebd35b
parent1ed485519d63e3fa9cd87cf336f39445d3488396 (diff)
parent3a090e2819c85abacae5dd244733408cb110e427 (diff)
downloadastroid-git-2e2d04b4c270f31c1a9883e9ae41bb2d4cbca3fe.tar.gz
Various speed improvements.
Patch by Alex Munroe.
-rw-r--r--ChangeLog2
-rw-r--r--__init__.py5
-rw-r--r--bases.py158
-rw-r--r--brain/py2stdlib.py10
-rw-r--r--inference.py87
-rw-r--r--mixins.py10
-rw-r--r--node_classes.py36
-rw-r--r--protocols.py14
-rw-r--r--rebuilder.py78
-rw-r--r--scoped_nodes.py98
-rw-r--r--test/unittest_inference.py2
-rw-r--r--test/unittest_nodes.py6
12 files changed, 254 insertions, 252 deletions
diff --git a/ChangeLog b/ChangeLog
index 0f648659..4288d89c 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -14,6 +14,8 @@ Change log for the astroid package (used to be astng)
* Fix an infinite loop with decorator call chain inference,
where the decorator returns itself. Closes issue #50.
+ * Various speed improvements. Patch by Alex Munroe.
+
2014-08-24 -- 1.2.1
* Fix a crash occurred when inferring decorator call chain.
diff --git a/__init__.py b/__init__.py
index 19c80902..6df06b1e 100644
--- a/__init__.py
+++ b/__init__.py
@@ -79,6 +79,9 @@ class AsStringRegexpPredicate(object):
If specified, the second argument is an `attrgetter` expression that will be
applied on the node first to get the actual node on which `as_string` should
be called.
+
+ WARNING: This can be fairly slow, as it has to convert every AST node back
+ to Python code; you should consider examining the AST directly instead.
"""
def __init__(self, regexp, expression=None):
self.regexp = re.compile(regexp)
@@ -98,7 +101,7 @@ def inference_tip(infer_function):
.. sourcecode:: python
MANAGER.register_transform(CallFunc, inference_tip(infer_named_tuple),
- AsStringRegexpPredicate('namedtuple', 'func'))
+ predicate)
"""
def transform(node, infer_function=infer_function):
node._explicit_inference = infer_function
diff --git a/bases.py b/bases.py
index 37e613b8..91ea91bf 100644
--- a/bases.py
+++ b/bases.py
@@ -24,6 +24,8 @@ __docformat__ = "restructuredtext en"
import sys
from contextlib import contextmanager
+from logilab.common.decorators import cachedproperty
+
from astroid.exceptions import (InferenceError, AstroidError, NotFoundError,
UnresolvableName, UseInferenceDefault)
@@ -56,63 +58,84 @@ class Proxy(object):
# Inference ##################################################################
+MISSING = object()
+
+
class InferenceContext(object):
- __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode')
+ __slots__ = ('path', 'callcontext', 'boundnode', 'infered')
- def __init__(self, path=None):
+ def __init__(self,
+ path=None, callcontext=None, boundnode=None, infered=None):
if path is None:
- self.path = set()
+ self.path = frozenset()
else:
self.path = path
- self.lookupname = None
- self.callcontext = None
- self.boundnode = None
-
- def push(self, node):
- name = self.lookupname
- if (node, name) in self.path:
- raise StopIteration()
- self.path.add((node, name))
-
- def clone(self):
- # XXX copy lookupname/callcontext ?
- clone = InferenceContext(self.path)
- clone.callcontext = self.callcontext
- clone.boundnode = self.boundnode
- return clone
+ self.callcontext = callcontext
+ self.boundnode = boundnode
+ if infered is None:
+ self.infered = {}
+ else:
+ self.infered = infered
+
+ def push(self, key):
+ # This returns a NEW context with the same attributes, but a new key
+ # added to `path`. The intention is that it's only passed to callees
+ # and then destroyed; otherwise scope() may not work correctly.
+ # The cache will be shared, since it's the same exact dict.
+ if key in self.path:
+ # End the containing generator
+ raise StopIteration
+
+ return InferenceContext(
+ self.path.union([key]),
+ self.callcontext,
+ self.boundnode,
+ self.infered,
+ )
@contextmanager
- def restore_path(self):
- path = set(self.path)
- yield
- self.path = path
-
-def copy_context(context):
- if context is not None:
- return context.clone()
- else:
- return InferenceContext()
+ def scope(self, callcontext=MISSING, boundnode=MISSING):
+ try:
+ orig = self.callcontext, self.boundnode
+ if callcontext is not MISSING:
+ self.callcontext = callcontext
+ if boundnode is not MISSING:
+ self.boundnode = boundnode
+ yield
+ finally:
+ self.callcontext, self.boundnode = orig
+
+ def cache_generator(self, key, generator):
+ results = []
+ for result in generator:
+ results.append(result)
+ yield result
+
+ self.infered[key] = tuple(results)
+ return
-def _infer_stmts(stmts, context, frame=None):
+def _infer_stmts(stmts, context, frame=None, lookupname=None):
"""return an iterator on statements inferred by each statement in <stmts>
"""
stmt = None
infered = False
- if context is not None:
- name = context.lookupname
- context = context.clone()
- else:
- name = None
+ if context is None:
context = InferenceContext()
for stmt in stmts:
if stmt is YES:
yield stmt
infered = True
continue
- context.lookupname = stmt._infer_name(frame, name)
+
+ kw = {}
+ infered_name = stmt._infer_name(frame, lookupname)
+ if infered_name is not None:
+ # only returns not None if .infer() accepts a lookupname kwarg
+ kw['lookupname'] = infered_name
+
try:
- for infered in stmt.infer(context):
+ for infered in stmt.infer(context, **kw):
yield infered
infered = True
except UnresolvableName:
@@ -170,20 +193,24 @@ class Instance(Proxy):
def igetattr(self, name, context=None):
"""inferred getattr"""
+ if not context:
+ context = InferenceContext()
try:
# avoid recursively inferring the same attr on the same class
- if context:
- context.push((self._proxied, name))
+ new_context = context.push((self._proxied, name))
# XXX frame should be self._proxied, or not ?
- get_attr = self.getattr(name, context, lookupclass=False)
- return _infer_stmts(self._wrap_attr(get_attr, context), context,
- frame=self)
+ get_attr = self.getattr(name, new_context, lookupclass=False)
+ return _infer_stmts(
+ self._wrap_attr(get_attr, new_context),
+ new_context,
+ frame=self,
+ )
except NotFoundError:
try:
# fallback to class'igetattr since it has some logic to handle
# descriptors
return self._wrap_attr(self._proxied.igetattr(name, context),
- context)
+ context)
except NotFoundError:
raise InferenceError(name)
@@ -274,9 +301,9 @@ class BoundMethod(UnboundMethod):
return True
def infer_call_result(self, caller, context):
- context = context.clone()
- context.boundnode = self.bound
- return self._proxied.infer_call_result(caller, context)
+ with context.scope(boundnode=self.bound):
+ for infered in self._proxied.infer_call_result(caller, context):
+ yield infered
class Generator(Instance):
@@ -308,7 +335,8 @@ def path_wrapper(func):
"""wrapper function handling context"""
if context is None:
context = InferenceContext()
- context.push(node)
+ context = context.push((node, kwargs.get('lookupname')))
+
yielded = set()
for res in _func(node, context, **kwargs):
# unproxy only true instance, not const, tuple, dict...
@@ -377,7 +405,15 @@ class NodeNG(object):
return self._explicit_inference(self, context, **kwargs)
except UseInferenceDefault:
pass
- return self._infer(context, **kwargs)
+
+ if not context:
+ return self._infer(context, **kwargs)
+
+ key = (self, kwargs.get('lookupname'), context.callcontext, context.boundnode)
+ if key in context.infered:
+ return iter(context.infered[key])
+
+ return context.cache_generator(key, self._infer(context, **kwargs))
def _repr_name(self):
"""return self.name or self.attrname or '' for nice representation"""
@@ -415,7 +451,7 @@ class NodeNG(object):
attr = getattr(self, field)
if not attr: # None or empty listy / tuple
continue
- if isinstance(attr, (list, tuple)):
+ if attr.__class__ in (list, tuple):
return attr[-1]
else:
return attr
@@ -506,16 +542,28 @@ class NodeNG(object):
# FIXME: raise an exception if nearest is None ?
return nearest[0]
- def set_line_info(self, lastchild):
+ # these are lazy because they're relatively expensive to compute for every
+ # single node, and they rarely get looked at
+
+ @cachedproperty
+ def fromlineno(self):
if self.lineno is None:
- self.fromlineno = self._fixed_source_line()
+ return self._fixed_source_line()
+ else:
+ return self.lineno
+
+ @cachedproperty
+ def tolineno(self):
+ if not self._astroid_fields:
+ # can't have children
+ lastchild = None
else:
- self.fromlineno = self.lineno
+ lastchild = self.last_child()
if lastchild is None:
- self.tolineno = self.fromlineno
+ return self.fromlineno
else:
- self.tolineno = lastchild.tolineno
- return
+ return lastchild.tolineno
+
# TODO / FIXME:
assert self.fromlineno is not None, self
assert self.tolineno is not None, self
diff --git a/brain/py2stdlib.py b/brain/py2stdlib.py
index d728071b..92c783b4 100644
--- a/brain/py2stdlib.py
+++ b/brain/py2stdlib.py
@@ -259,6 +259,14 @@ MODULE_TRANSFORMS['subprocess'] = subprocess_transform
# namedtuple support ###########################################################
+def looks_like_namedtuple(node):
+ func = node.func
+ if type(func) is nodes.Getattr:
+ return func.attrname == 'namedtuple'
+ if type(func) is nodes.Name:
+ return func.name == 'namedtuple'
+ return False
+
def infer_named_tuple(node, context=None):
"""Specific inference function for namedtuple CallFunc node"""
class_node, name, attributes = infer_func_form(node, nodes.Tuple._proxied,
@@ -336,7 +344,7 @@ def infer_enum_class(node, context=None):
return node
MANAGER.register_transform(nodes.CallFunc, inference_tip(infer_named_tuple),
- AsStringRegexpPredicate('namedtuple', 'func'))
+ looks_like_namedtuple)
MANAGER.register_transform(nodes.CallFunc, inference_tip(infer_enum),
AsStringRegexpPredicate('Enum', 'func'))
MANAGER.register_transform(nodes.Class, infer_enum_class)
diff --git a/inference.py b/inference.py
index 4186307a..3f216ddf 100644
--- a/inference.py
+++ b/inference.py
@@ -28,7 +28,7 @@ from astroid.manager import AstroidManager
from astroid.exceptions import (AstroidError, InferenceError, NoDefault,
NotFoundError, UnresolvableName)
from astroid.bases import (YES, Instance, InferenceContext,
- _infer_stmts, copy_context, path_wrapper,
+ _infer_stmts, path_wrapper,
raise_if_nothing_infered)
from astroid.protocols import (
_arguments_infer_argname,
@@ -175,93 +175,89 @@ def infer_name(self, context=None):
if not stmts:
raise UnresolvableName(self.name)
- context = context.clone()
- context.lookupname = self.name
- return _infer_stmts(stmts, context, frame)
+ return _infer_stmts(stmts, context, frame, self.name)
nodes.Name._infer = path_wrapper(infer_name)
nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper
def infer_callfunc(self, context=None):
"""infer a CallFunc node by trying to guess what the function returns"""
- callcontext = context.clone()
- callcontext.callcontext = CallContext(self.args, self.starargs, self.kwargs)
- callcontext.boundnode = None
+ if context is None:
+ context = InferenceContext()
for callee in self.func.infer(context):
- if callee is YES:
- yield callee
- continue
- try:
- if hasattr(callee, 'infer_call_result'):
- for infered in callee.infer_call_result(self, callcontext):
- yield infered
- except InferenceError:
- ## XXX log error ?
- continue
+ with context.scope(
+ callcontext=CallContext(self.args, self.starargs, self.kwargs),
+ boundnode=None,
+ ):
+ if callee is YES:
+ yield callee
+ continue
+ try:
+ if hasattr(callee, 'infer_call_result'):
+ for infered in callee.infer_call_result(self, context):
+ yield infered
+ except InferenceError:
+ ## XXX log error ?
+ continue
nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc))
-def infer_import(self, context=None, asname=True):
+def infer_import(self, context=None, asname=True, lookupname=None):
"""infer an Import node: return the imported module/object"""
- name = context.lookupname
- if name is None:
+ if lookupname is None:
raise InferenceError()
if asname:
- yield self.do_import_module(self.real_name(name))
+ yield self.do_import_module(self.real_name(lookupname))
else:
- yield self.do_import_module(name)
+ yield self.do_import_module(lookupname)
nodes.Import._infer = path_wrapper(infer_import)
def infer_name_module(self, name):
context = InferenceContext()
- context.lookupname = name
- return self.infer(context, asname=False)
+ return self.infer(context, asname=False, lookupname=name)
nodes.Import.infer_name_module = infer_name_module
-def infer_from(self, context=None, asname=True):
+def infer_from(self, context=None, asname=True, lookupname=None):
"""infer a From nodes: return the imported module/object"""
- name = context.lookupname
- if name is None:
+ if lookupname is None:
raise InferenceError()
if asname:
- name = self.real_name(name)
+ lookupname = self.real_name(lookupname)
module = self.do_import_module()
try:
- context = copy_context(context)
- context.lookupname = name
- return _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context)
+ return _infer_stmts(module.getattr(lookupname, ignore_locals=module is self.root()), context, lookupname=lookupname)
except NotFoundError:
- raise InferenceError(name)
+ raise InferenceError(lookupname)
nodes.From._infer = path_wrapper(infer_from)
def infer_getattr(self, context=None):
"""infer a Getattr node by using getattr on the associated object"""
- #context = context.clone()
+ if not context:
+ context = InferenceContext()
for owner in self.expr.infer(context):
if owner is YES:
yield owner
continue
try:
- context.boundnode = owner
- for obj in owner.igetattr(self.attrname, context):
- yield obj
- context.boundnode = None
+ with context.scope(boundnode=owner):
+ for obj in owner.igetattr(self.attrname, context):
+ yield obj
except (NotFoundError, InferenceError):
- context.boundnode = None
+ pass
except AttributeError:
# XXX method / function
- context.boundnode = None
+ pass
nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr))
nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper
-def infer_global(self, context=None):
- if context.lookupname is None:
+def infer_global(self, context=None, lookupname=None):
+ if lookupname is None:
raise InferenceError()
try:
- return _infer_stmts(self.root().getattr(context.lookupname), context)
+ return _infer_stmts(self.root().getattr(lookupname), context)
except NotFoundError:
raise InferenceError()
nodes.Global._infer = path_wrapper(infer_global)
@@ -347,11 +343,10 @@ def infer_binop(self, context=None):
nodes.BinOp._infer = path_wrapper(infer_binop)
-def infer_arguments(self, context=None):
- name = context.lookupname
- if name is None:
+def infer_arguments(self, context=None, lookupname=None):
+ if lookupname is None:
raise InferenceError()
- return _arguments_infer_argname(self, name, context)
+ return _arguments_infer_argname(self, lookupname, context)
nodes.Arguments._infer = infer_arguments
diff --git a/mixins.py b/mixins.py
index f9037245..1b34c402 100644
--- a/mixins.py
+++ b/mixins.py
@@ -18,16 +18,18 @@
"""This module contains some mixins for the different nodes.
"""
+from logilab.common.decorators import cachedproperty
+
from astroid.exceptions import (AstroidBuildingException, InferenceError,
NotFoundError)
class BlockRangeMixIn(object):
"""override block range """
- def set_line_info(self, lastchild):
- self.fromlineno = self.lineno
- self.tolineno = lastchild.tolineno
- self.blockstart_tolineno = self._blockstart_toline()
+
+ @cachedproperty
+ def blockstart_tolineno(self):
+ return self.lineno
def _elsed_block_range(self, lineno, orelse, last=None):
"""handle block line numbers range for try/finally, for, if and while
diff --git a/node_classes.py b/node_classes.py
index b92fc247..13376d11 100644
--- a/node_classes.py
+++ b/node_classes.py
@@ -20,6 +20,8 @@
import sys
+from logilab.common.decorators import cachedproperty
+
from astroid.exceptions import NoDefault
from astroid.bases import (NodeNG, Statement, Instance, InferenceContext,
_infer_stmts, YES, BUILTINS)
@@ -127,8 +129,7 @@ class LookupMixIn(object):
the lookup method
"""
frame, stmts = self.lookup(name)
- context = InferenceContext()
- return _infer_stmts(stmts, context, frame)
+ return _infer_stmts(stmts, None, frame)
def _filter_stmts(self, stmts, frame, offset):
"""filter statements to remove ignorable statements.
@@ -303,6 +304,11 @@ class Arguments(NodeNG, AssignTypeMixin):
return name
return None
+ @cachedproperty
+ def fromlineno(self):
+ lineno = super(Arguments, self).fromlineno
+ return max(lineno, self.parent.fromlineno)
+
def format_args(self):
"""return arguments formatted as string"""
result = []
@@ -597,7 +603,8 @@ class ExceptHandler(Statement, AssignTypeMixin):
name = None
body = None
- def _blockstart_toline(self):
+ @cachedproperty
+ def blockstart_tolineno(self):
if self.name:
return self.name.tolineno
elif self.type:
@@ -605,11 +612,6 @@ class ExceptHandler(Statement, AssignTypeMixin):
else:
return self.lineno
- def set_line_info(self, lastchild):
- self.fromlineno = self.lineno
- self.tolineno = lastchild.tolineno
- self.blockstart_tolineno = self._blockstart_toline()
-
def catch(self, exceptions):
if self.type is None or exceptions is None:
return True
@@ -640,7 +642,8 @@ class For(BlockRangeMixIn, AssignTypeMixin, Statement):
orelse = None
optional_assign = True
- def _blockstart_toline(self):
+ @cachedproperty
+ def blockstart_tolineno(self):
return self.iter.tolineno
@@ -675,7 +678,8 @@ class If(BlockRangeMixIn, Statement):
body = None
orelse = None
- def _blockstart_toline(self):
+ @cachedproperty
+ def blockstart_tolineno(self):
return self.test.tolineno
def block_range(self, lineno):
@@ -826,9 +830,6 @@ class TryExcept(BlockRangeMixIn, Statement):
def _infer_name(self, frame, name):
return name
- def _blockstart_toline(self):
- return self.lineno
-
def block_range(self, lineno):
"""handle block line numbers range for try/except statements"""
last = None
@@ -848,9 +849,6 @@ class TryFinally(BlockRangeMixIn, Statement):
body = None
finalbody = None
- def _blockstart_toline(self):
- return self.lineno
-
def block_range(self, lineno):
"""handle block line numbers range for try/finally statements"""
child = self.body[0]
@@ -894,7 +892,8 @@ class While(BlockRangeMixIn, Statement):
body = None
orelse = None
- def _blockstart_toline(self):
+ @cachedproperty
+ def blockstart_tolineno(self):
return self.test.tolineno
def block_range(self, lineno):
@@ -908,7 +907,8 @@ class With(BlockRangeMixIn, AssignTypeMixin, Statement):
items = None
body = None
- def _blockstart_toline(self):
+ @cachedproperty
+ def blockstart_tolineno(self):
return self.items[-1][0].tolineno
def get_children(self):
diff --git a/protocols.py b/protocols.py
index e7703a06..cc284285 100644
--- a/protocols.py
+++ b/protocols.py
@@ -23,7 +23,7 @@ __doctype__ = "restructuredtext en"
from astroid.exceptions import InferenceError, NoDefault, NotFoundError
from astroid.node_classes import unpack_infer
-from astroid.bases import copy_context, \
+from astroid.bases import InferenceContext, \
raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES
from astroid.nodes import const_factory
from astroid import nodes
@@ -282,7 +282,8 @@ def _arguments_infer_argname(self, name, context):
# if there is a default value, yield it. And then yield YES to reflect
# we can't guess given argument value
try:
- context = copy_context(context)
+ if context is None:
+ context = InferenceContext()
for infered in self.default_value(name).infer(context):
yield infered
yield YES
@@ -294,13 +295,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None):
if context.callcontext:
# reset call context/name
callcontext = context.callcontext
- context = copy_context(context)
- context.callcontext = None
- for infered in callcontext.infer_argument(self.parent, node.name, context):
- yield infered
- return
- for infered in _arguments_infer_argname(self, node.name, context):
- yield infered
+ return callcontext.infer_argument(self.parent, node.name, context)
+ return _arguments_infer_argname(self, node.name, context)
nodes.Arguments.assigned_stmts = arguments_assigned_stmts
diff --git a/rebuilder.py b/rebuilder.py
index 8fa7ee92..e3899fd5 100644
--- a/rebuilder.py
+++ b/rebuilder.py
@@ -99,7 +99,6 @@ def _init_set_doc(node, newnode):
newnode.doc = None
try:
if isinstance(node.body[0], Discard) and isinstance(node.body[0].value, Str):
- newnode.tolineno = node.body[0].lineno
newnode.doc = node.body[0].value.s
node.body = node.body[1:]
@@ -108,10 +107,8 @@ def _init_set_doc(node, newnode):
def _lineno_parent(oldnode, newnode, parent):
newnode.parent = parent
- if hasattr(oldnode, 'lineno'):
- newnode.lineno = oldnode.lineno
- if hasattr(oldnode, 'col_offset'):
- newnode.col_offset = oldnode.col_offset
+ newnode.lineno = oldnode.lineno
+ newnode.col_offset = oldnode.col_offset
def _set_infos(oldnode, newnode, parent):
newnode.parent = parent
@@ -119,14 +116,12 @@ def _set_infos(oldnode, newnode, parent):
newnode.lineno = oldnode.lineno
if hasattr(oldnode, 'col_offset'):
newnode.col_offset = oldnode.col_offset
- newnode.set_line_info(newnode.last_child()) # set_line_info accepts None
def _create_yield_node(node, parent, rebuilder, factory):
newnode = factory()
_lineno_parent(node, newnode, parent)
if node.value is not None:
newnode.value = rebuilder.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
@@ -146,10 +141,9 @@ class TreeRebuilder(object):
"""visit a Module node by returning a fresh instance of it"""
newnode = new.Module(modname, None)
newnode.package = package
- _lineno_parent(node, newnode, parent=None)
+ newnode.parent = None
_init_set_doc(node, newnode)
newnode.body = [self.visit(child, newnode) for child in node.body]
- newnode.set_line_info(newnode.last_child())
return self._transform(newnode)
def visit(self, node, parent):
@@ -174,7 +168,7 @@ class TreeRebuilder(object):
def visit_arguments(self, node, parent):
"""visit a Arguments node by returning a fresh instance of it"""
newnode = new.Arguments()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
self.asscontext = "Ass"
newnode.args = [self.visit(child, newnode) for child in node.args]
self.asscontext = None
@@ -210,7 +204,6 @@ class TreeRebuilder(object):
newnode.parent.set_local(vararg, newnode)
if kwarg:
newnode.parent.set_local(kwarg, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_assattr(self, node, parent):
@@ -221,7 +214,6 @@ class TreeRebuilder(object):
newnode.expr = self.visit(node.expr, newnode)
self.asscontext = assc
self._delayed_assattr.append(newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_assert(self, node, parent):
@@ -231,7 +223,6 @@ class TreeRebuilder(object):
newnode.test = self.visit(node.test, newnode)
if node.msg is not None:
newnode.fail = self.visit(node.msg, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_assign(self, node, parent):
@@ -259,7 +250,6 @@ class TreeRebuilder(object):
meth.extra_decorators.append(newnode.value)
except (AttributeError, KeyError):
continue
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_assname(self, node, parent, node_name=None):
@@ -279,7 +269,6 @@ class TreeRebuilder(object):
newnode.target = self.visit(node.target, newnode)
self.asscontext = None
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_backquote(self, node, parent):
@@ -287,7 +276,6 @@ class TreeRebuilder(object):
newnode = new.Backquote()
_lineno_parent(node, newnode, parent)
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_binop(self, node, parent):
@@ -297,7 +285,6 @@ class TreeRebuilder(object):
newnode.left = self.visit(node.left, newnode)
newnode.right = self.visit(node.right, newnode)
newnode.op = _BIN_OP_CLASSES[node.op.__class__]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_boolop(self, node, parent):
@@ -306,7 +293,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.values = [self.visit(child, newnode) for child in node.values]
newnode.op = _BOOL_OP_CLASSES[node.op.__class__]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_break(self, node, parent):
@@ -325,8 +311,8 @@ class TreeRebuilder(object):
newnode.starargs = self.visit(node.starargs, newnode)
if node.kwargs is not None:
newnode.kwargs = self.visit(node.kwargs, newnode)
- newnode.args.extend(self.visit(child, newnode) for child in node.keywords)
- newnode.set_line_info(newnode.last_child())
+ for child in node.keywords:
+ newnode.args.append(self.visit(child, newnode))
return newnode
def visit_class(self, node, parent):
@@ -338,7 +324,6 @@ class TreeRebuilder(object):
newnode.body = [self.visit(child, newnode) for child in node.body]
if 'decorator_list' in node._fields and node.decorator_list:# py >= 2.6
newnode.decorators = self.visit_decorators(node, newnode)
- newnode.set_line_info(newnode.last_child())
newnode.parent.frame().set_local(newnode.name, newnode)
return newnode
@@ -361,19 +346,17 @@ class TreeRebuilder(object):
newnode.left = self.visit(node.left, newnode)
newnode.ops = [(_CMP_OP_CLASSES[op.__class__], self.visit(expr, newnode))
for (op, expr) in zip(node.ops, node.comparators)]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_comprehension(self, node, parent):
"""visit a Comprehension node by returning a fresh instance of it"""
newnode = new.Comprehension()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
self.asscontext = "Ass"
newnode.target = self.visit(node.target, newnode)
self.asscontext = None
newnode.iter = self.visit(node.iter, newnode)
newnode.ifs = [self.visit(child, newnode) for child in node.ifs]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_decorators(self, node, parent):
@@ -387,7 +370,6 @@ class TreeRebuilder(object):
else:
decorators = node.decorator_list
newnode.nodes = [self.visit(child, newnode) for child in decorators]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_delete(self, node, parent):
@@ -397,7 +379,6 @@ class TreeRebuilder(object):
self.asscontext = "Del"
newnode.targets = [self.visit(child, newnode) for child in node.targets]
self.asscontext = None
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_dict(self, node, parent):
@@ -406,7 +387,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.items = [(self.visit(key, newnode), self.visit(value, newnode))
for key, value in zip(node.keys, node.values)]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_dictcomp(self, node, parent):
@@ -417,7 +397,6 @@ class TreeRebuilder(object):
newnode.value = self.visit(node.value, newnode)
newnode.generators = [self.visit(child, newnode)
for child in node.generators]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_discard(self, node, parent):
@@ -425,7 +404,6 @@ class TreeRebuilder(object):
newnode = new.Discard()
_lineno_parent(node, newnode, parent)
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_ellipsis(self, node, parent):
@@ -452,7 +430,6 @@ class TreeRebuilder(object):
newnode.name = self.visit(node.name, newnode)
self.asscontext = None
newnode.body = [self.visit(child, newnode) for child in node.body]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_exec(self, node, parent):
@@ -464,15 +441,13 @@ class TreeRebuilder(object):
newnode.globals = self.visit(node.globals, newnode)
if node.locals is not None:
newnode.locals = self.visit(node.locals, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_extslice(self, node, parent):
"""visit an ExtSlice node by returning a fresh instance of it"""
newnode = new.ExtSlice()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
newnode.dims = [self.visit(dim, newnode) for dim in node.dims]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_for(self, node, parent):
@@ -485,7 +460,6 @@ class TreeRebuilder(object):
newnode.iter = self.visit(node.iter, newnode)
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.orelse = [self.visit(child, newnode) for child in node.orelse]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_from(self, node, parent):
@@ -514,7 +488,6 @@ class TreeRebuilder(object):
newnode.decorators = self.visit_decorators(node, newnode)
if PY3K and node.returns:
newnode.returns = self.visit(node.returns, newnode)
- newnode.set_line_info(newnode.last_child())
self._global_names.pop()
frame = newnode.parent.frame()
if isinstance(frame, new.Class):
@@ -538,7 +511,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.elt = self.visit(node.elt, newnode)
newnode.generators = [self.visit(child, newnode) for child in node.generators]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_getattr(self, node, parent):
@@ -558,7 +530,6 @@ class TreeRebuilder(object):
newnode.expr = self.visit(node.value, newnode)
self.asscontext = asscontext
newnode.attrname = node.attr
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_global(self, node, parent):
@@ -577,7 +548,6 @@ class TreeRebuilder(object):
newnode.test = self.visit(node.test, newnode)
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.orelse = [self.visit(child, newnode) for child in node.orelse]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_ifexp(self, node, parent):
@@ -587,7 +557,6 @@ class TreeRebuilder(object):
newnode.test = self.visit(node.test, newnode)
newnode.body = self.visit(node.body, newnode)
newnode.orelse = self.visit(node.orelse, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_import(self, node, parent):
@@ -604,18 +573,16 @@ class TreeRebuilder(object):
def visit_index(self, node, parent):
"""visit a Index node by returning a fresh instance of it"""
newnode = new.Index()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_keyword(self, node, parent):
"""visit a Keyword node by returning a fresh instance of it"""
newnode = new.Keyword()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
newnode.arg = node.arg
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_lambda(self, node, parent):
@@ -624,7 +591,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.args = self.visit(node.args, newnode)
newnode.body = self.visit(node.body, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_list(self, node, parent):
@@ -632,7 +598,6 @@ class TreeRebuilder(object):
newnode = new.List()
_lineno_parent(node, newnode, parent)
newnode.elts = [self.visit(child, newnode) for child in node.elts]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_listcomp(self, node, parent):
@@ -642,7 +607,6 @@ class TreeRebuilder(object):
newnode.elt = self.visit(node.elt, newnode)
newnode.generators = [self.visit(child, newnode)
for child in node.generators]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_name(self, node, parent):
@@ -665,7 +629,6 @@ class TreeRebuilder(object):
# XXX REMOVE me :
if self.asscontext in ('Del', 'Ass'): # 'Aug' ??
self._save_assignment(newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_bytes(self, node, parent):
@@ -700,7 +663,6 @@ class TreeRebuilder(object):
if node.dest is not None:
newnode.dest = self.visit(node.dest, newnode)
newnode.values = [self.visit(child, newnode) for child in node.values]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_raise(self, node, parent):
@@ -713,7 +675,6 @@ class TreeRebuilder(object):
newnode.inst = self.visit(node.inst, newnode)
if node.tback is not None:
newnode.tback = self.visit(node.tback, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_return(self, node, parent):
@@ -722,7 +683,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
if node.value is not None:
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_set(self, node, parent):
@@ -730,7 +690,6 @@ class TreeRebuilder(object):
newnode = new.Set()
_lineno_parent(node, newnode, parent)
newnode.elts = [self.visit(child, newnode) for child in node.elts]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_setcomp(self, node, parent):
@@ -740,20 +699,18 @@ class TreeRebuilder(object):
newnode.elt = self.visit(node.elt, newnode)
newnode.generators = [self.visit(child, newnode)
for child in node.generators]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_slice(self, node, parent):
"""visit a Slice node by returning a fresh instance of it"""
newnode = new.Slice()
- _lineno_parent(node, newnode, parent)
+ newnode.parent = parent
if node.lower is not None:
newnode.lower = self.visit(node.lower, newnode)
if node.upper is not None:
newnode.upper = self.visit(node.upper, newnode)
if node.step is not None:
newnode.step = self.visit(node.step, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_subscript(self, node, parent):
@@ -764,7 +721,6 @@ class TreeRebuilder(object):
newnode.value = self.visit(node.value, newnode)
newnode.slice = self.visit(node.slice, newnode)
self.asscontext = subcontext
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_tryexcept(self, node, parent):
@@ -774,7 +730,6 @@ class TreeRebuilder(object):
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.handlers = [self.visit(child, newnode) for child in node.handlers]
newnode.orelse = [self.visit(child, newnode) for child in node.orelse]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_tryfinally(self, node, parent):
@@ -783,7 +738,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.finalbody = [self.visit(n, newnode) for n in node.finalbody]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_tuple(self, node, parent):
@@ -791,7 +745,6 @@ class TreeRebuilder(object):
newnode = new.Tuple()
_lineno_parent(node, newnode, parent)
newnode.elts = [self.visit(child, newnode) for child in node.elts]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_unaryop(self, node, parent):
@@ -800,7 +753,6 @@ class TreeRebuilder(object):
_lineno_parent(node, newnode, parent)
newnode.operand = self.visit(node.operand, newnode)
newnode.op = _UNARY_OP_CLASSES[node.op.__class__]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_while(self, node, parent):
@@ -810,7 +762,6 @@ class TreeRebuilder(object):
newnode.test = self.visit(node.test, newnode)
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.orelse = [self.visit(child, newnode) for child in node.orelse]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_with(self, node, parent):
@@ -825,7 +776,6 @@ class TreeRebuilder(object):
self.asscontext = None
newnode.items = [(expr, vars)]
newnode.body = [self.visit(child, newnode) for child in node.body]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_yield(self, node, parent):
@@ -867,7 +817,6 @@ class TreeRebuilder3k(TreeRebuilder):
if node.name is not None:
newnode.name = self.visit_assname(node, newnode, node.name)
newnode.body = [self.visit(child, newnode) for child in node.body]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_nonlocal(self, node, parent):
@@ -885,7 +834,6 @@ class TreeRebuilder3k(TreeRebuilder):
newnode.exc = self.visit(node.exc, newnode)
if node.cause is not None:
newnode.cause = self.visit(node.cause, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_starred(self, node, parent):
@@ -893,7 +841,6 @@ class TreeRebuilder3k(TreeRebuilder):
newnode = new.Starred()
_lineno_parent(node, newnode, parent)
newnode.value = self.visit(node.value, newnode)
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_try(self, node, parent):
@@ -908,7 +855,6 @@ class TreeRebuilder3k(TreeRebuilder):
excnode.body = [self.visit(child, excnode) for child in node.body]
excnode.handlers = [self.visit(child, excnode) for child in node.handlers]
excnode.orelse = [self.visit(child, excnode) for child in node.orelse]
- excnode.set_line_info(excnode.last_child())
newnode.body = [excnode]
else:
newnode.body = [self.visit(child, newnode) for child in node.body]
@@ -918,7 +864,6 @@ class TreeRebuilder3k(TreeRebuilder):
newnode.body = [self.visit(child, newnode) for child in node.body]
newnode.handlers = [self.visit(child, newnode) for child in node.handlers]
newnode.orelse = [self.visit(child, newnode) for child in node.orelse]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_with(self, node, parent):
@@ -940,7 +885,6 @@ class TreeRebuilder3k(TreeRebuilder):
newnode.items = [visit_child(child)
for child in node.items]
newnode.body = [self.visit(child, newnode) for child in node.body]
- newnode.set_line_info(newnode.last_child())
return newnode
def visit_yieldfrom(self, node, parent):
diff --git a/scoped_nodes.py b/scoped_nodes.py
index 08fc67bf..002a7894 100644
--- a/scoped_nodes.py
+++ b/scoped_nodes.py
@@ -39,7 +39,7 @@ from astroid.node_classes import Const, DelName, DelAttr, \
Dict, From, List, Pass, Raise, Return, Tuple, Yield, YieldFrom, \
LookupMixIn, const_factory as cf, unpack_infer, Name, CallFunc
from astroid.bases import NodeNG, InferenceContext, Instance,\
- YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, copy_context, \
+ YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, \
BUILTINS
from astroid.mixins import FilterStmtsMixin
from astroid.bases import Statement
@@ -311,10 +311,10 @@ class Module(LocalsDictNodeNG):
"""inferred getattr"""
# set lookup name since this is necessary to infer on import nodes for
# instance
- context = copy_context(context)
- context.lookupname = name
+ if not context:
+ context = InferenceContext()
try:
- return _infer_stmts(self.getattr(name, context), context, frame=self)
+ return _infer_stmts(self.getattr(name, context), context, frame=self, lookupname=name)
except NotFoundError:
raise InferenceError(name)
@@ -339,13 +339,17 @@ class Module(LocalsDictNodeNG):
return
if sys.version_info < (2, 8):
- def absolute_import_activated(self):
+ @cachedproperty
+ def _absolute_import_activated(self):
for stmt in self.locals.get('absolute_import', ()):
if isinstance(stmt, From) and stmt.modname == '__future__':
return True
return False
else:
- absolute_import_activated = lambda self: True
+ _absolute_import_activated = True
+
+ def absolute_import_activated(self):
+ return self._absolute_import_activated
def import_module(self, modname, relative_only=False, level=None):
"""import the given module considering self as context"""
@@ -633,22 +637,25 @@ class Function(Statement, Lambda):
self.locals = {}
self.args = []
self.body = []
- self.decorators = None
self.name = name
self.doc = doc
self.extra_decorators = []
self.instance_attrs = {}
- def set_line_info(self, lastchild):
- self.fromlineno = self.lineno
- # lineno is the line number of the first decorator, we want the def statement lineno
+ @cachedproperty
+ def fromlineno(self):
+ # lineno is the line number of the first decorator, we want the def
+ # statement lineno
+ lineno = self.lineno
if self.decorators is not None:
- self.fromlineno += sum(node.tolineno - node.lineno + 1
+ lineno += sum(node.tolineno - node.lineno + 1
for node in self.decorators.nodes)
- if self.args.fromlineno < self.fromlineno:
- self.args.fromlineno = self.fromlineno
- self.tolineno = lastchild.tolineno
- self.blockstart_tolineno = self.args.tolineno
+
+ return lineno
+
+ @cachedproperty
+ def blockstart_tolineno(self):
+ return self.args.tolineno
def block_range(self, lineno):
"""return block line numbers.
@@ -884,12 +891,12 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
doc="boolean indicating if it's a new style class"
"or not")
- def set_line_info(self, lastchild):
- self.fromlineno = self.lineno
- self.blockstart_tolineno = self.bases and self.bases[-1].tolineno or self.fromlineno
- if lastchild is not None:
- self.tolineno = lastchild.tolineno
- # else this is a class with only a docstring, then tolineno is (should be) already ok
+ @cachedproperty
+ def blockstart_tolineno(self):
+ if self.bases:
+ return self.bases[-1].tolineno
+ else:
+ return self.fromlineno
def block_range(self, lineno):
"""return block line numbers.
@@ -971,28 +978,27 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
if context is None:
context = InferenceContext()
for stmt in self.bases:
- with context.restore_path():
- try:
- for baseobj in stmt.infer(context):
- if not isinstance(baseobj, Class):
- if isinstance(baseobj, Instance):
- baseobj = baseobj._proxied
- else:
- # duh ?
- continue
- if baseobj in yielded:
- continue # cf xxx above
- yielded.add(baseobj)
- yield baseobj
- if recurs:
- for grandpa in baseobj.ancestors(True, context):
- if grandpa in yielded:
- continue # cf xxx above
- yielded.add(grandpa)
- yield grandpa
- except InferenceError:
- # XXX log error ?
- continue
+ try:
+ for baseobj in stmt.infer(context):
+ if not isinstance(baseobj, Class):
+ if isinstance(baseobj, Instance):
+ baseobj = baseobj._proxied
+ else:
+ # duh ?
+ continue
+ if baseobj in yielded:
+ continue # cf xxx above
+ yielded.add(baseobj)
+ yield baseobj
+ if recurs:
+ for grandpa in baseobj.ancestors(True, context):
+ if grandpa in yielded:
+ continue # cf xxx above
+ yielded.add(grandpa)
+ yield grandpa
+ except InferenceError:
+ # XXX log error ?
+ continue
def local_attr_ancestors(self, name, context=None):
"""return an iterator on astroid representation of parent classes
@@ -1087,11 +1093,11 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
"""
# set lookup name since this is necessary to infer on import nodes for
# instance
- context = copy_context(context)
- context.lookupname = name
+ if not context:
+ context = InferenceContext()
try:
for infered in _infer_stmts(self.getattr(name, context), context,
- frame=self):
+ frame=self, lookupname=name):
# yield YES object instead of descriptors when necessary
if not isinstance(infered, Const) and isinstance(infered, Instance):
try:
diff --git a/test/unittest_inference.py b/test/unittest_inference.py
index 7cabf034..a5883810 100644
--- a/test/unittest_inference.py
+++ b/test/unittest_inference.py
@@ -913,7 +913,7 @@ def f(g = lambda: None):
g().x
'''
astroid = builder.string_build(code, __name__, __file__)
- callfuncnode = astroid['f'].body[0].value.expr
+ callfuncnode = astroid['f'].body[0].value.expr # 'g()'
infered = list(callfuncnode.infer())
self.assertEqual(len(infered), 2, infered)
infered.remove(YES)
diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py
index 790c3b22..d4f164ac 100644
--- a/test/unittest_nodes.py
+++ b/test/unittest_nodes.py
@@ -308,11 +308,9 @@ except PickleError:
def test_absolute_import(self):
astroid = abuilder.file_build(self.datapath('absimport.py'))
ctx = InferenceContext()
- ctx.lookupname = 'message'
# will fail if absolute import failed
- astroid['message'].infer(ctx).next()
- ctx.lookupname = 'email'
- m = astroid['email'].infer(ctx).next()
+ astroid['message'].infer(ctx, lookupname='message').next()
+ m = astroid['email'].infer(ctx, lookupname='email').next()
self.assertFalse(m.file.startswith(self.datapath('email.py')))
def test_more_absolute_import(self):