diff options
| author | Claudiu Popa <pcmanticore@gmail.com> | 2014-10-20 17:41:05 +0300 |
|---|---|---|
| committer | Claudiu Popa <pcmanticore@gmail.com> | 2014-10-20 17:41:05 +0300 |
| commit | 2e2d04b4c270f31c1a9883e9ae41bb2d4cbca3fe (patch) | |
| tree | 0be6f15bf00f1a276418ca292cad029d53ebd35b | |
| parent | 1ed485519d63e3fa9cd87cf336f39445d3488396 (diff) | |
| parent | 3a090e2819c85abacae5dd244733408cb110e427 (diff) | |
| download | astroid-git-2e2d04b4c270f31c1a9883e9ae41bb2d4cbca3fe.tar.gz | |
Various speed improvements.
Patch by Alex Munroe.
| -rw-r--r-- | ChangeLog | 2 | ||||
| -rw-r--r-- | __init__.py | 5 | ||||
| -rw-r--r-- | bases.py | 158 | ||||
| -rw-r--r-- | brain/py2stdlib.py | 10 | ||||
| -rw-r--r-- | inference.py | 87 | ||||
| -rw-r--r-- | mixins.py | 10 | ||||
| -rw-r--r-- | node_classes.py | 36 | ||||
| -rw-r--r-- | protocols.py | 14 | ||||
| -rw-r--r-- | rebuilder.py | 78 | ||||
| -rw-r--r-- | scoped_nodes.py | 98 | ||||
| -rw-r--r-- | test/unittest_inference.py | 2 | ||||
| -rw-r--r-- | test/unittest_nodes.py | 6 |
12 files changed, 254 insertions, 252 deletions
@@ -14,6 +14,8 @@ Change log for the astroid package (used to be astng) * Fix an infinite loop with decorator call chain inference, where the decorator returns itself. Closes issue #50. + * Various speed improvements. Patch by Alex Munroe. + 2014-08-24 -- 1.2.1 * Fix a crash occurred when inferring decorator call chain. diff --git a/__init__.py b/__init__.py index 19c80902..6df06b1e 100644 --- a/__init__.py +++ b/__init__.py @@ -79,6 +79,9 @@ class AsStringRegexpPredicate(object): If specified, the second argument is an `attrgetter` expression that will be applied on the node first to get the actual node on which `as_string` should be called. + + WARNING: This can be fairly slow, as it has to convert every AST node back + to Python code; you should consider examining the AST directly instead. """ def __init__(self, regexp, expression=None): self.regexp = re.compile(regexp) @@ -98,7 +101,7 @@ def inference_tip(infer_function): .. sourcecode:: python MANAGER.register_transform(CallFunc, inference_tip(infer_named_tuple), - AsStringRegexpPredicate('namedtuple', 'func')) + predicate) """ def transform(node, infer_function=infer_function): node._explicit_inference = infer_function @@ -24,6 +24,8 @@ __docformat__ = "restructuredtext en" import sys from contextlib import contextmanager +from logilab.common.decorators import cachedproperty + from astroid.exceptions import (InferenceError, AstroidError, NotFoundError, UnresolvableName, UseInferenceDefault) @@ -56,63 +58,84 @@ class Proxy(object): # Inference ################################################################## +MISSING = object() + + class InferenceContext(object): - __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode') + __slots__ = ('path', 'callcontext', 'boundnode', 'infered') - def __init__(self, path=None): + def __init__(self, + path=None, callcontext=None, boundnode=None, infered=None): if path is None: - self.path = set() + self.path = frozenset() else: self.path = path - self.lookupname = None - self.callcontext = None - self.boundnode = None - - def push(self, node): - name = self.lookupname - if (node, name) in self.path: - raise StopIteration() - self.path.add((node, name)) - - def clone(self): - # XXX copy lookupname/callcontext ? - clone = InferenceContext(self.path) - clone.callcontext = self.callcontext - clone.boundnode = self.boundnode - return clone + self.callcontext = callcontext + self.boundnode = boundnode + if infered is None: + self.infered = {} + else: + self.infered = infered + + def push(self, key): + # This returns a NEW context with the same attributes, but a new key + # added to `path`. The intention is that it's only passed to callees + # and then destroyed; otherwise scope() may not work correctly. + # The cache will be shared, since it's the same exact dict. + if key in self.path: + # End the containing generator + raise StopIteration + + return InferenceContext( + self.path.union([key]), + self.callcontext, + self.boundnode, + self.infered, + ) @contextmanager - def restore_path(self): - path = set(self.path) - yield - self.path = path - -def copy_context(context): - if context is not None: - return context.clone() - else: - return InferenceContext() + def scope(self, callcontext=MISSING, boundnode=MISSING): + try: + orig = self.callcontext, self.boundnode + if callcontext is not MISSING: + self.callcontext = callcontext + if boundnode is not MISSING: + self.boundnode = boundnode + yield + finally: + self.callcontext, self.boundnode = orig + + def cache_generator(self, key, generator): + results = [] + for result in generator: + results.append(result) + yield result + + self.infered[key] = tuple(results) + return -def _infer_stmts(stmts, context, frame=None): +def _infer_stmts(stmts, context, frame=None, lookupname=None): """return an iterator on statements inferred by each statement in <stmts> """ stmt = None infered = False - if context is not None: - name = context.lookupname - context = context.clone() - else: - name = None + if context is None: context = InferenceContext() for stmt in stmts: if stmt is YES: yield stmt infered = True continue - context.lookupname = stmt._infer_name(frame, name) + + kw = {} + infered_name = stmt._infer_name(frame, lookupname) + if infered_name is not None: + # only returns not None if .infer() accepts a lookupname kwarg + kw['lookupname'] = infered_name + try: - for infered in stmt.infer(context): + for infered in stmt.infer(context, **kw): yield infered infered = True except UnresolvableName: @@ -170,20 +193,24 @@ class Instance(Proxy): def igetattr(self, name, context=None): """inferred getattr""" + if not context: + context = InferenceContext() try: # avoid recursively inferring the same attr on the same class - if context: - context.push((self._proxied, name)) + new_context = context.push((self._proxied, name)) # XXX frame should be self._proxied, or not ? - get_attr = self.getattr(name, context, lookupclass=False) - return _infer_stmts(self._wrap_attr(get_attr, context), context, - frame=self) + get_attr = self.getattr(name, new_context, lookupclass=False) + return _infer_stmts( + self._wrap_attr(get_attr, new_context), + new_context, + frame=self, + ) except NotFoundError: try: # fallback to class'igetattr since it has some logic to handle # descriptors return self._wrap_attr(self._proxied.igetattr(name, context), - context) + context) except NotFoundError: raise InferenceError(name) @@ -274,9 +301,9 @@ class BoundMethod(UnboundMethod): return True def infer_call_result(self, caller, context): - context = context.clone() - context.boundnode = self.bound - return self._proxied.infer_call_result(caller, context) + with context.scope(boundnode=self.bound): + for infered in self._proxied.infer_call_result(caller, context): + yield infered class Generator(Instance): @@ -308,7 +335,8 @@ def path_wrapper(func): """wrapper function handling context""" if context is None: context = InferenceContext() - context.push(node) + context = context.push((node, kwargs.get('lookupname'))) + yielded = set() for res in _func(node, context, **kwargs): # unproxy only true instance, not const, tuple, dict... @@ -377,7 +405,15 @@ class NodeNG(object): return self._explicit_inference(self, context, **kwargs) except UseInferenceDefault: pass - return self._infer(context, **kwargs) + + if not context: + return self._infer(context, **kwargs) + + key = (self, kwargs.get('lookupname'), context.callcontext, context.boundnode) + if key in context.infered: + return iter(context.infered[key]) + + return context.cache_generator(key, self._infer(context, **kwargs)) def _repr_name(self): """return self.name or self.attrname or '' for nice representation""" @@ -415,7 +451,7 @@ class NodeNG(object): attr = getattr(self, field) if not attr: # None or empty listy / tuple continue - if isinstance(attr, (list, tuple)): + if attr.__class__ in (list, tuple): return attr[-1] else: return attr @@ -506,16 +542,28 @@ class NodeNG(object): # FIXME: raise an exception if nearest is None ? return nearest[0] - def set_line_info(self, lastchild): + # these are lazy because they're relatively expensive to compute for every + # single node, and they rarely get looked at + + @cachedproperty + def fromlineno(self): if self.lineno is None: - self.fromlineno = self._fixed_source_line() + return self._fixed_source_line() + else: + return self.lineno + + @cachedproperty + def tolineno(self): + if not self._astroid_fields: + # can't have children + lastchild = None else: - self.fromlineno = self.lineno + lastchild = self.last_child() if lastchild is None: - self.tolineno = self.fromlineno + return self.fromlineno else: - self.tolineno = lastchild.tolineno - return + return lastchild.tolineno + # TODO / FIXME: assert self.fromlineno is not None, self assert self.tolineno is not None, self diff --git a/brain/py2stdlib.py b/brain/py2stdlib.py index d728071b..92c783b4 100644 --- a/brain/py2stdlib.py +++ b/brain/py2stdlib.py @@ -259,6 +259,14 @@ MODULE_TRANSFORMS['subprocess'] = subprocess_transform # namedtuple support ########################################################### +def looks_like_namedtuple(node): + func = node.func + if type(func) is nodes.Getattr: + return func.attrname == 'namedtuple' + if type(func) is nodes.Name: + return func.name == 'namedtuple' + return False + def infer_named_tuple(node, context=None): """Specific inference function for namedtuple CallFunc node""" class_node, name, attributes = infer_func_form(node, nodes.Tuple._proxied, @@ -336,7 +344,7 @@ def infer_enum_class(node, context=None): return node MANAGER.register_transform(nodes.CallFunc, inference_tip(infer_named_tuple), - AsStringRegexpPredicate('namedtuple', 'func')) + looks_like_namedtuple) MANAGER.register_transform(nodes.CallFunc, inference_tip(infer_enum), AsStringRegexpPredicate('Enum', 'func')) MANAGER.register_transform(nodes.Class, infer_enum_class) diff --git a/inference.py b/inference.py index 4186307a..3f216ddf 100644 --- a/inference.py +++ b/inference.py @@ -28,7 +28,7 @@ from astroid.manager import AstroidManager from astroid.exceptions import (AstroidError, InferenceError, NoDefault, NotFoundError, UnresolvableName) from astroid.bases import (YES, Instance, InferenceContext, - _infer_stmts, copy_context, path_wrapper, + _infer_stmts, path_wrapper, raise_if_nothing_infered) from astroid.protocols import ( _arguments_infer_argname, @@ -175,93 +175,89 @@ def infer_name(self, context=None): if not stmts: raise UnresolvableName(self.name) - context = context.clone() - context.lookupname = self.name - return _infer_stmts(stmts, context, frame) + return _infer_stmts(stmts, context, frame, self.name) nodes.Name._infer = path_wrapper(infer_name) nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper def infer_callfunc(self, context=None): """infer a CallFunc node by trying to guess what the function returns""" - callcontext = context.clone() - callcontext.callcontext = CallContext(self.args, self.starargs, self.kwargs) - callcontext.boundnode = None + if context is None: + context = InferenceContext() for callee in self.func.infer(context): - if callee is YES: - yield callee - continue - try: - if hasattr(callee, 'infer_call_result'): - for infered in callee.infer_call_result(self, callcontext): - yield infered - except InferenceError: - ## XXX log error ? - continue + with context.scope( + callcontext=CallContext(self.args, self.starargs, self.kwargs), + boundnode=None, + ): + if callee is YES: + yield callee + continue + try: + if hasattr(callee, 'infer_call_result'): + for infered in callee.infer_call_result(self, context): + yield infered + except InferenceError: + ## XXX log error ? + continue nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc)) -def infer_import(self, context=None, asname=True): +def infer_import(self, context=None, asname=True, lookupname=None): """infer an Import node: return the imported module/object""" - name = context.lookupname - if name is None: + if lookupname is None: raise InferenceError() if asname: - yield self.do_import_module(self.real_name(name)) + yield self.do_import_module(self.real_name(lookupname)) else: - yield self.do_import_module(name) + yield self.do_import_module(lookupname) nodes.Import._infer = path_wrapper(infer_import) def infer_name_module(self, name): context = InferenceContext() - context.lookupname = name - return self.infer(context, asname=False) + return self.infer(context, asname=False, lookupname=name) nodes.Import.infer_name_module = infer_name_module -def infer_from(self, context=None, asname=True): +def infer_from(self, context=None, asname=True, lookupname=None): """infer a From nodes: return the imported module/object""" - name = context.lookupname - if name is None: + if lookupname is None: raise InferenceError() if asname: - name = self.real_name(name) + lookupname = self.real_name(lookupname) module = self.do_import_module() try: - context = copy_context(context) - context.lookupname = name - return _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context) + return _infer_stmts(module.getattr(lookupname, ignore_locals=module is self.root()), context, lookupname=lookupname) except NotFoundError: - raise InferenceError(name) + raise InferenceError(lookupname) nodes.From._infer = path_wrapper(infer_from) def infer_getattr(self, context=None): """infer a Getattr node by using getattr on the associated object""" - #context = context.clone() + if not context: + context = InferenceContext() for owner in self.expr.infer(context): if owner is YES: yield owner continue try: - context.boundnode = owner - for obj in owner.igetattr(self.attrname, context): - yield obj - context.boundnode = None + with context.scope(boundnode=owner): + for obj in owner.igetattr(self.attrname, context): + yield obj except (NotFoundError, InferenceError): - context.boundnode = None + pass except AttributeError: # XXX method / function - context.boundnode = None + pass nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr)) nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper -def infer_global(self, context=None): - if context.lookupname is None: +def infer_global(self, context=None, lookupname=None): + if lookupname is None: raise InferenceError() try: - return _infer_stmts(self.root().getattr(context.lookupname), context) + return _infer_stmts(self.root().getattr(lookupname), context) except NotFoundError: raise InferenceError() nodes.Global._infer = path_wrapper(infer_global) @@ -347,11 +343,10 @@ def infer_binop(self, context=None): nodes.BinOp._infer = path_wrapper(infer_binop) -def infer_arguments(self, context=None): - name = context.lookupname - if name is None: +def infer_arguments(self, context=None, lookupname=None): + if lookupname is None: raise InferenceError() - return _arguments_infer_argname(self, name, context) + return _arguments_infer_argname(self, lookupname, context) nodes.Arguments._infer = infer_arguments @@ -18,16 +18,18 @@ """This module contains some mixins for the different nodes. """ +from logilab.common.decorators import cachedproperty + from astroid.exceptions import (AstroidBuildingException, InferenceError, NotFoundError) class BlockRangeMixIn(object): """override block range """ - def set_line_info(self, lastchild): - self.fromlineno = self.lineno - self.tolineno = lastchild.tolineno - self.blockstart_tolineno = self._blockstart_toline() + + @cachedproperty + def blockstart_tolineno(self): + return self.lineno def _elsed_block_range(self, lineno, orelse, last=None): """handle block line numbers range for try/finally, for, if and while diff --git a/node_classes.py b/node_classes.py index b92fc247..13376d11 100644 --- a/node_classes.py +++ b/node_classes.py @@ -20,6 +20,8 @@ import sys +from logilab.common.decorators import cachedproperty + from astroid.exceptions import NoDefault from astroid.bases import (NodeNG, Statement, Instance, InferenceContext, _infer_stmts, YES, BUILTINS) @@ -127,8 +129,7 @@ class LookupMixIn(object): the lookup method """ frame, stmts = self.lookup(name) - context = InferenceContext() - return _infer_stmts(stmts, context, frame) + return _infer_stmts(stmts, None, frame) def _filter_stmts(self, stmts, frame, offset): """filter statements to remove ignorable statements. @@ -303,6 +304,11 @@ class Arguments(NodeNG, AssignTypeMixin): return name return None + @cachedproperty + def fromlineno(self): + lineno = super(Arguments, self).fromlineno + return max(lineno, self.parent.fromlineno) + def format_args(self): """return arguments formatted as string""" result = [] @@ -597,7 +603,8 @@ class ExceptHandler(Statement, AssignTypeMixin): name = None body = None - def _blockstart_toline(self): + @cachedproperty + def blockstart_tolineno(self): if self.name: return self.name.tolineno elif self.type: @@ -605,11 +612,6 @@ class ExceptHandler(Statement, AssignTypeMixin): else: return self.lineno - def set_line_info(self, lastchild): - self.fromlineno = self.lineno - self.tolineno = lastchild.tolineno - self.blockstart_tolineno = self._blockstart_toline() - def catch(self, exceptions): if self.type is None or exceptions is None: return True @@ -640,7 +642,8 @@ class For(BlockRangeMixIn, AssignTypeMixin, Statement): orelse = None optional_assign = True - def _blockstart_toline(self): + @cachedproperty + def blockstart_tolineno(self): return self.iter.tolineno @@ -675,7 +678,8 @@ class If(BlockRangeMixIn, Statement): body = None orelse = None - def _blockstart_toline(self): + @cachedproperty + def blockstart_tolineno(self): return self.test.tolineno def block_range(self, lineno): @@ -826,9 +830,6 @@ class TryExcept(BlockRangeMixIn, Statement): def _infer_name(self, frame, name): return name - def _blockstart_toline(self): - return self.lineno - def block_range(self, lineno): """handle block line numbers range for try/except statements""" last = None @@ -848,9 +849,6 @@ class TryFinally(BlockRangeMixIn, Statement): body = None finalbody = None - def _blockstart_toline(self): - return self.lineno - def block_range(self, lineno): """handle block line numbers range for try/finally statements""" child = self.body[0] @@ -894,7 +892,8 @@ class While(BlockRangeMixIn, Statement): body = None orelse = None - def _blockstart_toline(self): + @cachedproperty + def blockstart_tolineno(self): return self.test.tolineno def block_range(self, lineno): @@ -908,7 +907,8 @@ class With(BlockRangeMixIn, AssignTypeMixin, Statement): items = None body = None - def _blockstart_toline(self): + @cachedproperty + def blockstart_tolineno(self): return self.items[-1][0].tolineno def get_children(self): diff --git a/protocols.py b/protocols.py index e7703a06..cc284285 100644 --- a/protocols.py +++ b/protocols.py @@ -23,7 +23,7 @@ __doctype__ = "restructuredtext en" from astroid.exceptions import InferenceError, NoDefault, NotFoundError from astroid.node_classes import unpack_infer -from astroid.bases import copy_context, \ +from astroid.bases import InferenceContext, \ raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES from astroid.nodes import const_factory from astroid import nodes @@ -282,7 +282,8 @@ def _arguments_infer_argname(self, name, context): # if there is a default value, yield it. And then yield YES to reflect # we can't guess given argument value try: - context = copy_context(context) + if context is None: + context = InferenceContext() for infered in self.default_value(name).infer(context): yield infered yield YES @@ -294,13 +295,8 @@ def arguments_assigned_stmts(self, node, context, asspath=None): if context.callcontext: # reset call context/name callcontext = context.callcontext - context = copy_context(context) - context.callcontext = None - for infered in callcontext.infer_argument(self.parent, node.name, context): - yield infered - return - for infered in _arguments_infer_argname(self, node.name, context): - yield infered + return callcontext.infer_argument(self.parent, node.name, context) + return _arguments_infer_argname(self, node.name, context) nodes.Arguments.assigned_stmts = arguments_assigned_stmts diff --git a/rebuilder.py b/rebuilder.py index 8fa7ee92..e3899fd5 100644 --- a/rebuilder.py +++ b/rebuilder.py @@ -99,7 +99,6 @@ def _init_set_doc(node, newnode): newnode.doc = None try: if isinstance(node.body[0], Discard) and isinstance(node.body[0].value, Str): - newnode.tolineno = node.body[0].lineno newnode.doc = node.body[0].value.s node.body = node.body[1:] @@ -108,10 +107,8 @@ def _init_set_doc(node, newnode): def _lineno_parent(oldnode, newnode, parent): newnode.parent = parent - if hasattr(oldnode, 'lineno'): - newnode.lineno = oldnode.lineno - if hasattr(oldnode, 'col_offset'): - newnode.col_offset = oldnode.col_offset + newnode.lineno = oldnode.lineno + newnode.col_offset = oldnode.col_offset def _set_infos(oldnode, newnode, parent): newnode.parent = parent @@ -119,14 +116,12 @@ def _set_infos(oldnode, newnode, parent): newnode.lineno = oldnode.lineno if hasattr(oldnode, 'col_offset'): newnode.col_offset = oldnode.col_offset - newnode.set_line_info(newnode.last_child()) # set_line_info accepts None def _create_yield_node(node, parent, rebuilder, factory): newnode = factory() _lineno_parent(node, newnode, parent) if node.value is not None: newnode.value = rebuilder.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode @@ -146,10 +141,9 @@ class TreeRebuilder(object): """visit a Module node by returning a fresh instance of it""" newnode = new.Module(modname, None) newnode.package = package - _lineno_parent(node, newnode, parent=None) + newnode.parent = None _init_set_doc(node, newnode) newnode.body = [self.visit(child, newnode) for child in node.body] - newnode.set_line_info(newnode.last_child()) return self._transform(newnode) def visit(self, node, parent): @@ -174,7 +168,7 @@ class TreeRebuilder(object): def visit_arguments(self, node, parent): """visit a Arguments node by returning a fresh instance of it""" newnode = new.Arguments() - _lineno_parent(node, newnode, parent) + newnode.parent = parent self.asscontext = "Ass" newnode.args = [self.visit(child, newnode) for child in node.args] self.asscontext = None @@ -210,7 +204,6 @@ class TreeRebuilder(object): newnode.parent.set_local(vararg, newnode) if kwarg: newnode.parent.set_local(kwarg, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_assattr(self, node, parent): @@ -221,7 +214,6 @@ class TreeRebuilder(object): newnode.expr = self.visit(node.expr, newnode) self.asscontext = assc self._delayed_assattr.append(newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_assert(self, node, parent): @@ -231,7 +223,6 @@ class TreeRebuilder(object): newnode.test = self.visit(node.test, newnode) if node.msg is not None: newnode.fail = self.visit(node.msg, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_assign(self, node, parent): @@ -259,7 +250,6 @@ class TreeRebuilder(object): meth.extra_decorators.append(newnode.value) except (AttributeError, KeyError): continue - newnode.set_line_info(newnode.last_child()) return newnode def visit_assname(self, node, parent, node_name=None): @@ -279,7 +269,6 @@ class TreeRebuilder(object): newnode.target = self.visit(node.target, newnode) self.asscontext = None newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_backquote(self, node, parent): @@ -287,7 +276,6 @@ class TreeRebuilder(object): newnode = new.Backquote() _lineno_parent(node, newnode, parent) newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_binop(self, node, parent): @@ -297,7 +285,6 @@ class TreeRebuilder(object): newnode.left = self.visit(node.left, newnode) newnode.right = self.visit(node.right, newnode) newnode.op = _BIN_OP_CLASSES[node.op.__class__] - newnode.set_line_info(newnode.last_child()) return newnode def visit_boolop(self, node, parent): @@ -306,7 +293,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.values = [self.visit(child, newnode) for child in node.values] newnode.op = _BOOL_OP_CLASSES[node.op.__class__] - newnode.set_line_info(newnode.last_child()) return newnode def visit_break(self, node, parent): @@ -325,8 +311,8 @@ class TreeRebuilder(object): newnode.starargs = self.visit(node.starargs, newnode) if node.kwargs is not None: newnode.kwargs = self.visit(node.kwargs, newnode) - newnode.args.extend(self.visit(child, newnode) for child in node.keywords) - newnode.set_line_info(newnode.last_child()) + for child in node.keywords: + newnode.args.append(self.visit(child, newnode)) return newnode def visit_class(self, node, parent): @@ -338,7 +324,6 @@ class TreeRebuilder(object): newnode.body = [self.visit(child, newnode) for child in node.body] if 'decorator_list' in node._fields and node.decorator_list:# py >= 2.6 newnode.decorators = self.visit_decorators(node, newnode) - newnode.set_line_info(newnode.last_child()) newnode.parent.frame().set_local(newnode.name, newnode) return newnode @@ -361,19 +346,17 @@ class TreeRebuilder(object): newnode.left = self.visit(node.left, newnode) newnode.ops = [(_CMP_OP_CLASSES[op.__class__], self.visit(expr, newnode)) for (op, expr) in zip(node.ops, node.comparators)] - newnode.set_line_info(newnode.last_child()) return newnode def visit_comprehension(self, node, parent): """visit a Comprehension node by returning a fresh instance of it""" newnode = new.Comprehension() - _lineno_parent(node, newnode, parent) + newnode.parent = parent self.asscontext = "Ass" newnode.target = self.visit(node.target, newnode) self.asscontext = None newnode.iter = self.visit(node.iter, newnode) newnode.ifs = [self.visit(child, newnode) for child in node.ifs] - newnode.set_line_info(newnode.last_child()) return newnode def visit_decorators(self, node, parent): @@ -387,7 +370,6 @@ class TreeRebuilder(object): else: decorators = node.decorator_list newnode.nodes = [self.visit(child, newnode) for child in decorators] - newnode.set_line_info(newnode.last_child()) return newnode def visit_delete(self, node, parent): @@ -397,7 +379,6 @@ class TreeRebuilder(object): self.asscontext = "Del" newnode.targets = [self.visit(child, newnode) for child in node.targets] self.asscontext = None - newnode.set_line_info(newnode.last_child()) return newnode def visit_dict(self, node, parent): @@ -406,7 +387,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.items = [(self.visit(key, newnode), self.visit(value, newnode)) for key, value in zip(node.keys, node.values)] - newnode.set_line_info(newnode.last_child()) return newnode def visit_dictcomp(self, node, parent): @@ -417,7 +397,6 @@ class TreeRebuilder(object): newnode.value = self.visit(node.value, newnode) newnode.generators = [self.visit(child, newnode) for child in node.generators] - newnode.set_line_info(newnode.last_child()) return newnode def visit_discard(self, node, parent): @@ -425,7 +404,6 @@ class TreeRebuilder(object): newnode = new.Discard() _lineno_parent(node, newnode, parent) newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_ellipsis(self, node, parent): @@ -452,7 +430,6 @@ class TreeRebuilder(object): newnode.name = self.visit(node.name, newnode) self.asscontext = None newnode.body = [self.visit(child, newnode) for child in node.body] - newnode.set_line_info(newnode.last_child()) return newnode def visit_exec(self, node, parent): @@ -464,15 +441,13 @@ class TreeRebuilder(object): newnode.globals = self.visit(node.globals, newnode) if node.locals is not None: newnode.locals = self.visit(node.locals, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_extslice(self, node, parent): """visit an ExtSlice node by returning a fresh instance of it""" newnode = new.ExtSlice() - _lineno_parent(node, newnode, parent) + newnode.parent = parent newnode.dims = [self.visit(dim, newnode) for dim in node.dims] - newnode.set_line_info(newnode.last_child()) return newnode def visit_for(self, node, parent): @@ -485,7 +460,6 @@ class TreeRebuilder(object): newnode.iter = self.visit(node.iter, newnode) newnode.body = [self.visit(child, newnode) for child in node.body] newnode.orelse = [self.visit(child, newnode) for child in node.orelse] - newnode.set_line_info(newnode.last_child()) return newnode def visit_from(self, node, parent): @@ -514,7 +488,6 @@ class TreeRebuilder(object): newnode.decorators = self.visit_decorators(node, newnode) if PY3K and node.returns: newnode.returns = self.visit(node.returns, newnode) - newnode.set_line_info(newnode.last_child()) self._global_names.pop() frame = newnode.parent.frame() if isinstance(frame, new.Class): @@ -538,7 +511,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.elt = self.visit(node.elt, newnode) newnode.generators = [self.visit(child, newnode) for child in node.generators] - newnode.set_line_info(newnode.last_child()) return newnode def visit_getattr(self, node, parent): @@ -558,7 +530,6 @@ class TreeRebuilder(object): newnode.expr = self.visit(node.value, newnode) self.asscontext = asscontext newnode.attrname = node.attr - newnode.set_line_info(newnode.last_child()) return newnode def visit_global(self, node, parent): @@ -577,7 +548,6 @@ class TreeRebuilder(object): newnode.test = self.visit(node.test, newnode) newnode.body = [self.visit(child, newnode) for child in node.body] newnode.orelse = [self.visit(child, newnode) for child in node.orelse] - newnode.set_line_info(newnode.last_child()) return newnode def visit_ifexp(self, node, parent): @@ -587,7 +557,6 @@ class TreeRebuilder(object): newnode.test = self.visit(node.test, newnode) newnode.body = self.visit(node.body, newnode) newnode.orelse = self.visit(node.orelse, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_import(self, node, parent): @@ -604,18 +573,16 @@ class TreeRebuilder(object): def visit_index(self, node, parent): """visit a Index node by returning a fresh instance of it""" newnode = new.Index() - _lineno_parent(node, newnode, parent) + newnode.parent = parent newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_keyword(self, node, parent): """visit a Keyword node by returning a fresh instance of it""" newnode = new.Keyword() - _lineno_parent(node, newnode, parent) + newnode.parent = parent newnode.arg = node.arg newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_lambda(self, node, parent): @@ -624,7 +591,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.args = self.visit(node.args, newnode) newnode.body = self.visit(node.body, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_list(self, node, parent): @@ -632,7 +598,6 @@ class TreeRebuilder(object): newnode = new.List() _lineno_parent(node, newnode, parent) newnode.elts = [self.visit(child, newnode) for child in node.elts] - newnode.set_line_info(newnode.last_child()) return newnode def visit_listcomp(self, node, parent): @@ -642,7 +607,6 @@ class TreeRebuilder(object): newnode.elt = self.visit(node.elt, newnode) newnode.generators = [self.visit(child, newnode) for child in node.generators] - newnode.set_line_info(newnode.last_child()) return newnode def visit_name(self, node, parent): @@ -665,7 +629,6 @@ class TreeRebuilder(object): # XXX REMOVE me : if self.asscontext in ('Del', 'Ass'): # 'Aug' ?? self._save_assignment(newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_bytes(self, node, parent): @@ -700,7 +663,6 @@ class TreeRebuilder(object): if node.dest is not None: newnode.dest = self.visit(node.dest, newnode) newnode.values = [self.visit(child, newnode) for child in node.values] - newnode.set_line_info(newnode.last_child()) return newnode def visit_raise(self, node, parent): @@ -713,7 +675,6 @@ class TreeRebuilder(object): newnode.inst = self.visit(node.inst, newnode) if node.tback is not None: newnode.tback = self.visit(node.tback, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_return(self, node, parent): @@ -722,7 +683,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) if node.value is not None: newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_set(self, node, parent): @@ -730,7 +690,6 @@ class TreeRebuilder(object): newnode = new.Set() _lineno_parent(node, newnode, parent) newnode.elts = [self.visit(child, newnode) for child in node.elts] - newnode.set_line_info(newnode.last_child()) return newnode def visit_setcomp(self, node, parent): @@ -740,20 +699,18 @@ class TreeRebuilder(object): newnode.elt = self.visit(node.elt, newnode) newnode.generators = [self.visit(child, newnode) for child in node.generators] - newnode.set_line_info(newnode.last_child()) return newnode def visit_slice(self, node, parent): """visit a Slice node by returning a fresh instance of it""" newnode = new.Slice() - _lineno_parent(node, newnode, parent) + newnode.parent = parent if node.lower is not None: newnode.lower = self.visit(node.lower, newnode) if node.upper is not None: newnode.upper = self.visit(node.upper, newnode) if node.step is not None: newnode.step = self.visit(node.step, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_subscript(self, node, parent): @@ -764,7 +721,6 @@ class TreeRebuilder(object): newnode.value = self.visit(node.value, newnode) newnode.slice = self.visit(node.slice, newnode) self.asscontext = subcontext - newnode.set_line_info(newnode.last_child()) return newnode def visit_tryexcept(self, node, parent): @@ -774,7 +730,6 @@ class TreeRebuilder(object): newnode.body = [self.visit(child, newnode) for child in node.body] newnode.handlers = [self.visit(child, newnode) for child in node.handlers] newnode.orelse = [self.visit(child, newnode) for child in node.orelse] - newnode.set_line_info(newnode.last_child()) return newnode def visit_tryfinally(self, node, parent): @@ -783,7 +738,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.body = [self.visit(child, newnode) for child in node.body] newnode.finalbody = [self.visit(n, newnode) for n in node.finalbody] - newnode.set_line_info(newnode.last_child()) return newnode def visit_tuple(self, node, parent): @@ -791,7 +745,6 @@ class TreeRebuilder(object): newnode = new.Tuple() _lineno_parent(node, newnode, parent) newnode.elts = [self.visit(child, newnode) for child in node.elts] - newnode.set_line_info(newnode.last_child()) return newnode def visit_unaryop(self, node, parent): @@ -800,7 +753,6 @@ class TreeRebuilder(object): _lineno_parent(node, newnode, parent) newnode.operand = self.visit(node.operand, newnode) newnode.op = _UNARY_OP_CLASSES[node.op.__class__] - newnode.set_line_info(newnode.last_child()) return newnode def visit_while(self, node, parent): @@ -810,7 +762,6 @@ class TreeRebuilder(object): newnode.test = self.visit(node.test, newnode) newnode.body = [self.visit(child, newnode) for child in node.body] newnode.orelse = [self.visit(child, newnode) for child in node.orelse] - newnode.set_line_info(newnode.last_child()) return newnode def visit_with(self, node, parent): @@ -825,7 +776,6 @@ class TreeRebuilder(object): self.asscontext = None newnode.items = [(expr, vars)] newnode.body = [self.visit(child, newnode) for child in node.body] - newnode.set_line_info(newnode.last_child()) return newnode def visit_yield(self, node, parent): @@ -867,7 +817,6 @@ class TreeRebuilder3k(TreeRebuilder): if node.name is not None: newnode.name = self.visit_assname(node, newnode, node.name) newnode.body = [self.visit(child, newnode) for child in node.body] - newnode.set_line_info(newnode.last_child()) return newnode def visit_nonlocal(self, node, parent): @@ -885,7 +834,6 @@ class TreeRebuilder3k(TreeRebuilder): newnode.exc = self.visit(node.exc, newnode) if node.cause is not None: newnode.cause = self.visit(node.cause, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_starred(self, node, parent): @@ -893,7 +841,6 @@ class TreeRebuilder3k(TreeRebuilder): newnode = new.Starred() _lineno_parent(node, newnode, parent) newnode.value = self.visit(node.value, newnode) - newnode.set_line_info(newnode.last_child()) return newnode def visit_try(self, node, parent): @@ -908,7 +855,6 @@ class TreeRebuilder3k(TreeRebuilder): excnode.body = [self.visit(child, excnode) for child in node.body] excnode.handlers = [self.visit(child, excnode) for child in node.handlers] excnode.orelse = [self.visit(child, excnode) for child in node.orelse] - excnode.set_line_info(excnode.last_child()) newnode.body = [excnode] else: newnode.body = [self.visit(child, newnode) for child in node.body] @@ -918,7 +864,6 @@ class TreeRebuilder3k(TreeRebuilder): newnode.body = [self.visit(child, newnode) for child in node.body] newnode.handlers = [self.visit(child, newnode) for child in node.handlers] newnode.orelse = [self.visit(child, newnode) for child in node.orelse] - newnode.set_line_info(newnode.last_child()) return newnode def visit_with(self, node, parent): @@ -940,7 +885,6 @@ class TreeRebuilder3k(TreeRebuilder): newnode.items = [visit_child(child) for child in node.items] newnode.body = [self.visit(child, newnode) for child in node.body] - newnode.set_line_info(newnode.last_child()) return newnode def visit_yieldfrom(self, node, parent): diff --git a/scoped_nodes.py b/scoped_nodes.py index 08fc67bf..002a7894 100644 --- a/scoped_nodes.py +++ b/scoped_nodes.py @@ -39,7 +39,7 @@ from astroid.node_classes import Const, DelName, DelAttr, \ Dict, From, List, Pass, Raise, Return, Tuple, Yield, YieldFrom, \ LookupMixIn, const_factory as cf, unpack_infer, Name, CallFunc from astroid.bases import NodeNG, InferenceContext, Instance,\ - YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, copy_context, \ + YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, \ BUILTINS from astroid.mixins import FilterStmtsMixin from astroid.bases import Statement @@ -311,10 +311,10 @@ class Module(LocalsDictNodeNG): """inferred getattr""" # set lookup name since this is necessary to infer on import nodes for # instance - context = copy_context(context) - context.lookupname = name + if not context: + context = InferenceContext() try: - return _infer_stmts(self.getattr(name, context), context, frame=self) + return _infer_stmts(self.getattr(name, context), context, frame=self, lookupname=name) except NotFoundError: raise InferenceError(name) @@ -339,13 +339,17 @@ class Module(LocalsDictNodeNG): return if sys.version_info < (2, 8): - def absolute_import_activated(self): + @cachedproperty + def _absolute_import_activated(self): for stmt in self.locals.get('absolute_import', ()): if isinstance(stmt, From) and stmt.modname == '__future__': return True return False else: - absolute_import_activated = lambda self: True + _absolute_import_activated = True + + def absolute_import_activated(self): + return self._absolute_import_activated def import_module(self, modname, relative_only=False, level=None): """import the given module considering self as context""" @@ -633,22 +637,25 @@ class Function(Statement, Lambda): self.locals = {} self.args = [] self.body = [] - self.decorators = None self.name = name self.doc = doc self.extra_decorators = [] self.instance_attrs = {} - def set_line_info(self, lastchild): - self.fromlineno = self.lineno - # lineno is the line number of the first decorator, we want the def statement lineno + @cachedproperty + def fromlineno(self): + # lineno is the line number of the first decorator, we want the def + # statement lineno + lineno = self.lineno if self.decorators is not None: - self.fromlineno += sum(node.tolineno - node.lineno + 1 + lineno += sum(node.tolineno - node.lineno + 1 for node in self.decorators.nodes) - if self.args.fromlineno < self.fromlineno: - self.args.fromlineno = self.fromlineno - self.tolineno = lastchild.tolineno - self.blockstart_tolineno = self.args.tolineno + + return lineno + + @cachedproperty + def blockstart_tolineno(self): + return self.args.tolineno def block_range(self, lineno): """return block line numbers. @@ -884,12 +891,12 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): doc="boolean indicating if it's a new style class" "or not") - def set_line_info(self, lastchild): - self.fromlineno = self.lineno - self.blockstart_tolineno = self.bases and self.bases[-1].tolineno or self.fromlineno - if lastchild is not None: - self.tolineno = lastchild.tolineno - # else this is a class with only a docstring, then tolineno is (should be) already ok + @cachedproperty + def blockstart_tolineno(self): + if self.bases: + return self.bases[-1].tolineno + else: + return self.fromlineno def block_range(self, lineno): """return block line numbers. @@ -971,28 +978,27 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): if context is None: context = InferenceContext() for stmt in self.bases: - with context.restore_path(): - try: - for baseobj in stmt.infer(context): - if not isinstance(baseobj, Class): - if isinstance(baseobj, Instance): - baseobj = baseobj._proxied - else: - # duh ? - continue - if baseobj in yielded: - continue # cf xxx above - yielded.add(baseobj) - yield baseobj - if recurs: - for grandpa in baseobj.ancestors(True, context): - if grandpa in yielded: - continue # cf xxx above - yielded.add(grandpa) - yield grandpa - except InferenceError: - # XXX log error ? - continue + try: + for baseobj in stmt.infer(context): + if not isinstance(baseobj, Class): + if isinstance(baseobj, Instance): + baseobj = baseobj._proxied + else: + # duh ? + continue + if baseobj in yielded: + continue # cf xxx above + yielded.add(baseobj) + yield baseobj + if recurs: + for grandpa in baseobj.ancestors(True, context): + if grandpa in yielded: + continue # cf xxx above + yielded.add(grandpa) + yield grandpa + except InferenceError: + # XXX log error ? + continue def local_attr_ancestors(self, name, context=None): """return an iterator on astroid representation of parent classes @@ -1087,11 +1093,11 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): """ # set lookup name since this is necessary to infer on import nodes for # instance - context = copy_context(context) - context.lookupname = name + if not context: + context = InferenceContext() try: for infered in _infer_stmts(self.getattr(name, context), context, - frame=self): + frame=self, lookupname=name): # yield YES object instead of descriptors when necessary if not isinstance(infered, Const) and isinstance(infered, Instance): try: diff --git a/test/unittest_inference.py b/test/unittest_inference.py index 7cabf034..a5883810 100644 --- a/test/unittest_inference.py +++ b/test/unittest_inference.py @@ -913,7 +913,7 @@ def f(g = lambda: None): g().x ''' astroid = builder.string_build(code, __name__, __file__) - callfuncnode = astroid['f'].body[0].value.expr + callfuncnode = astroid['f'].body[0].value.expr # 'g()' infered = list(callfuncnode.infer()) self.assertEqual(len(infered), 2, infered) infered.remove(YES) diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py index 790c3b22..d4f164ac 100644 --- a/test/unittest_nodes.py +++ b/test/unittest_nodes.py @@ -308,11 +308,9 @@ except PickleError: def test_absolute_import(self): astroid = abuilder.file_build(self.datapath('absimport.py')) ctx = InferenceContext() - ctx.lookupname = 'message' # will fail if absolute import failed - astroid['message'].infer(ctx).next() - ctx.lookupname = 'email' - m = astroid['email'].infer(ctx).next() + astroid['message'].infer(ctx, lookupname='message').next() + m = astroid['email'].infer(ctx, lookupname='email').next() self.assertFalse(m.file.startswith(self.datapath('email.py'))) def test_more_absolute_import(self): |
