diff options
| -rw-r--r-- | bases.py | 130 | ||||
| -rw-r--r-- | inference.py | 63 | ||||
| -rw-r--r-- | node_classes.py | 5 | ||||
| -rw-r--r-- | protocols.py | 19 | ||||
| -rw-r--r-- | scoped_nodes.py | 55 | ||||
| -rw-r--r-- | test/unittest_inference.py | 2 | ||||
| -rw-r--r-- | test/unittest_nodes.py | 10 | ||||
| -rw-r--r-- | test/unittest_scoped_nodes.py | 3 |
8 files changed, 160 insertions, 127 deletions
@@ -56,6 +56,31 @@ class Proxy(object): # Inference ################################################################## +MISSING = object() + + +class InferenceContextPathContext(object): + """Implementation of InferenceContext.push. + + Can't be a @contextmanager because it raises StopIteration. + """ + def __init__(self, context, node): + self.original_path = context.path.copy() + self.context = context + self.node = node + + def __enter__(self): + name = self.context.lookupname + if (self.node, name) in self.context.path: + raise StopIteration + + self.context.path.add((self.node, name)) + return self + + def __exit__(self, *exc_info): + self.context.path = self.original_path + + class InferenceContext(object): __slots__ = ('path', 'lookupname', 'callcontext', 'boundnode') @@ -69,17 +94,21 @@ class InferenceContext(object): self.boundnode = None def push(self, node): - name = self.lookupname - if (node, name) in self.path: - raise StopIteration() - self.path.add( (node, name) ) - - def clone(self): - # XXX copy lookupname/callcontext ? - clone = InferenceContext(self.path) - clone.callcontext = self.callcontext - clone.boundnode = self.boundnode - return clone + return InferenceContextPathContext(self, node) + + @contextmanager + def scope(self, lookupname=MISSING, callcontext=MISSING, boundnode=MISSING): + try: + orig = self.lookupname, self.callcontext, self.boundnode + if lookupname is not MISSING: + self.lookupname = lookupname + if callcontext is not MISSING: + self.callcontext = callcontext + if boundnode is not MISSING: + self.boundnode = boundnode + yield + finally: + self.lookupname, self.callcontext, self.boundnode = orig @contextmanager def restore_path(self): @@ -87,39 +116,30 @@ class InferenceContext(object): yield self.path = path -def copy_context(context): - if context is not None: - return context.clone() - else: - return InferenceContext() - def _infer_stmts(stmts, context, frame=None): """return an iterator on statements inferred by each statement in <stmts> """ stmt = None infered = False - if context is not None: - name = context.lookupname - context = context.clone() - else: - name = None + if context is None: context = InferenceContext() + name = context.lookupname for stmt in stmts: if stmt is YES: yield stmt infered = True continue - context.lookupname = stmt._infer_name(frame, name) - try: - for infered in stmt.infer(context): - yield infered + with context.scope(lookupname=stmt._infer_name(frame, name)): + try: + for infered in stmt.infer(context): + yield infered + infered = True + except UnresolvableName: + continue + except InferenceError: + yield YES infered = True - except UnresolvableName: - continue - except InferenceError: - yield YES - infered = True if not infered: raise InferenceError(str(stmt)) @@ -170,20 +190,23 @@ class Instance(Proxy): def igetattr(self, name, context=None): """inferred getattr""" + if not context: + context = InferenceContext() try: # avoid recursively inferring the same attr on the same class - if context: - context.push((self._proxied, name)) - # XXX frame should be self._proxied, or not ? - get_attr = self.getattr(name, context, lookupclass=False) - return _infer_stmts(self._wrap_attr(get_attr, context), context, - frame=self) + with context.push((self._proxied, name)): + # XXX frame should be self._proxied, or not ? + get_attr = self.getattr(name, context, lookupclass=False) + for infered in _infer_stmts(self._wrap_attr(get_attr, context), context, + frame=self): + yield infered except NotFoundError: try: # fallback to class'igetattr since it has some logic to handle # descriptors - return self._wrap_attr(self._proxied.igetattr(name, context), - context) + for attr in self._wrap_attr(self._proxied.igetattr(name, context), + context): + yield attr except NotFoundError: raise InferenceError(name) @@ -274,9 +297,9 @@ class BoundMethod(UnboundMethod): return True def infer_call_result(self, caller, context): - context = context.clone() - context.boundnode = self.bound - return self._proxied.infer_call_result(caller, context) + with context.scope(boundnode=self.bound, lookupname=None): + for infered in self._proxied.infer_call_result(caller, context): + yield infered class Generator(Instance): @@ -308,17 +331,17 @@ def path_wrapper(func): """wrapper function handling context""" if context is None: context = InferenceContext() - context.push(node) - yielded = set() - for res in _func(node, context, **kwargs): - # unproxy only true instance, not const, tuple, dict... - if res.__class__ is Instance: - ares = res._proxied - else: - ares = res - if not ares in yielded: - yield res - yielded.add(ares) + with context.push(node): + yielded = set() + for res in _func(node, context, **kwargs): + # unproxy only true instance, not const, tuple, dict... + if res.__class__ is Instance: + ares = res._proxied + else: + ares = res + if not ares in yielded: + yield res + yielded.add(ares) return wrapped def yes_if_nothing_infered(func): @@ -377,6 +400,7 @@ class NodeNG(object): return self._explicit_inference(self, context, **kwargs) except UseInferenceDefault: pass + return self._infer(context, **kwargs) def _repr_name(self): diff --git a/inference.py b/inference.py index 35cce332..7bc2313b 100644 --- a/inference.py +++ b/inference.py @@ -28,7 +28,7 @@ from astroid.manager import AstroidManager from astroid.exceptions import (AstroidError, InferenceError, NoDefault, NotFoundError, UnresolvableName) from astroid.bases import YES, Instance, InferenceContext, \ - _infer_stmts, copy_context, path_wrapper, raise_if_nothing_infered + _infer_stmts, path_wrapper, raise_if_nothing_infered from astroid.protocols import _arguments_infer_argname MANAGER = AstroidManager() @@ -146,29 +146,33 @@ def infer_name(self, context=None): frame, stmts = self.lookup(self.name) if not stmts: raise UnresolvableName(self.name) - context = context.clone() - context.lookupname = self.name - return _infer_stmts(stmts, context, frame) + with context.scope(lookupname=self.name): + for infered in _infer_stmts(stmts, context, frame): + yield infered nodes.Name._infer = path_wrapper(infer_name) nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper def infer_callfunc(self, context=None): """infer a CallFunc node by trying to guess what the function returns""" - callcontext = context.clone() - callcontext.callcontext = CallContext(self.args, self.starargs, self.kwargs) - callcontext.boundnode = None + if context is None: + context = InferenceContext() for callee in self.func.infer(context): - if callee is YES: - yield callee - continue - try: - if hasattr(callee, 'infer_call_result'): - for infered in callee.infer_call_result(self, callcontext): - yield infered - except InferenceError: - ## XXX log error ? - continue + with context.scope( + callcontext=CallContext(self.args, self.starargs, self.kwargs), + boundnode=None, + lookupname=None, + ): + if callee is YES: + yield callee + continue + try: + if hasattr(callee, 'infer_call_result'): + for infered in callee.infer_call_result(self, context): + yield infered + except InferenceError: + ## XXX log error ? + continue nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc)) @@ -185,8 +189,9 @@ nodes.Import._infer = path_wrapper(infer_import) def infer_name_module(self, name): context = InferenceContext() - context.lookupname = name - return self.infer(context, asname=False) + with context.scope(lookupname=name): + for infered in self.infer(context, asname=False): + yield infered nodes.Import.infer_name_module = infer_name_module @@ -199,9 +204,9 @@ def infer_from(self, context=None, asname=True): name = self.real_name(name) module = self.do_import_module(self.modname) try: - context = copy_context(context) - context.lookupname = name - return _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context) + with context.scope(lookupname=name): + for infered in _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context): + yield infered except NotFoundError: raise InferenceError(name) nodes.From._infer = path_wrapper(infer_from) @@ -209,21 +214,21 @@ nodes.From._infer = path_wrapper(infer_from) def infer_getattr(self, context=None): """infer a Getattr node by using getattr on the associated object""" - #context = context.clone() + if not context: + context = InferenceContext() for owner in self.expr.infer(context): if owner is YES: yield owner continue try: - context.boundnode = owner - for obj in owner.igetattr(self.attrname, context): - yield obj - context.boundnode = None + with context.scope(boundnode=owner, lookupname=None): + for obj in owner.igetattr(self.attrname, context): + yield obj except (NotFoundError, InferenceError): - context.boundnode = None + pass except AttributeError: # XXX method / function - context.boundnode = None + pass nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr)) nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper diff --git a/node_classes.py b/node_classes.py index 01dc8d92..e247392e 100644 --- a/node_classes.py +++ b/node_classes.py @@ -125,8 +125,7 @@ class LookupMixIn(object): the lookup method """ frame, stmts = self.lookup(name) - context = InferenceContext() - return _infer_stmts(stmts, context, frame) + return _infer_stmts(stmts, None, frame) def _filter_stmts(self, stmts, frame, offset): """filter statements to remove ignorable statements. @@ -889,7 +888,7 @@ class Yield(NodeNG): value = None class YieldFrom(Yield): - """ Class representing a YieldFrom node. """ + """ Class representing a YieldFrom node. """ # constants ############################################################## diff --git a/protocols.py b/protocols.py index e66b802c..616340c9 100644 --- a/protocols.py +++ b/protocols.py @@ -23,7 +23,7 @@ __doctype__ = "restructuredtext en" from astroid.exceptions import InferenceError, NoDefault from astroid.node_classes import unpack_infer -from astroid.bases import copy_context, \ +from astroid.bases import InferenceContext, \ raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES from astroid.nodes import const_factory from astroid import nodes @@ -239,9 +239,11 @@ def _arguments_infer_argname(self, name, context): # if there is a default value, yield it. And then yield YES to reflect # we can't guess given argument value try: - context = copy_context(context) - for infered in self.default_value(name).infer(context): - yield infered + if context is None: + context = InferenceContext() + with context.scope(lookupname=None): + for infered in self.default_value(name).infer(context): + yield infered yield YES except NoDefault: yield YES @@ -251,11 +253,10 @@ def arguments_assigned_stmts(self, node, context, asspath=None): if context.callcontext: # reset call context/name callcontext = context.callcontext - context = copy_context(context) - context.callcontext = None - for infered in callcontext.infer_argument(self.parent, node.name, context): - yield infered - return + with context.scope(callcontext=None, lookupname=None): + for infered in callcontext.infer_argument(self.parent, node.name, context): + yield infered + return for infered in _arguments_infer_argname(self, node.name, context): yield infered nodes.Arguments.assigned_stmts = arguments_assigned_stmts diff --git a/scoped_nodes.py b/scoped_nodes.py index e354c5b8..a5bc37e9 100644 --- a/scoped_nodes.py +++ b/scoped_nodes.py @@ -39,7 +39,7 @@ from astroid.node_classes import Const, DelName, DelAttr, \ Dict, From, List, Pass, Raise, Return, Tuple, Yield, YieldFrom, \ LookupMixIn, const_factory as cf, unpack_infer, Name from astroid.bases import NodeNG, InferenceContext, Instance,\ - YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, copy_context, \ + YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, \ BUILTINS from astroid.mixins import FilterStmtsMixin from astroid.bases import Statement @@ -308,12 +308,14 @@ class Module(LocalsDictNodeNG): """inferred getattr""" # set lookup name since this is necessary to infer on import nodes for # instance - context = copy_context(context) - context.lookupname = name - try: - return _infer_stmts(self.getattr(name, context), context, frame=self) - except NotFoundError: - raise InferenceError(name) + if not context: + context = InferenceContext() + with context.scope(lookupname=name): + try: + for infered in _infer_stmts(self.getattr(name, context), context, frame=self): + yield infered + except NotFoundError: + raise InferenceError(name) def fully_defined(self): """return True if this module has been built from a .py file @@ -993,27 +995,28 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): """ # set lookup name since this is necessary to infer on import nodes for # instance - context = copy_context(context) - context.lookupname = name - try: - for infered in _infer_stmts(self.getattr(name, context), context, - frame=self): - # yield YES object instead of descriptors when necessary - if not isinstance(infered, Const) and isinstance(infered, Instance): - try: - infered._proxied.getattr('__get__', context) - except NotFoundError: - yield infered + if not context: + context = InferenceContext() + with context.scope(lookupname=name): + try: + for infered in _infer_stmts(self.getattr(name, context), context, + frame=self): + # yield YES object instead of descriptors when necessary + if not isinstance(infered, Const) and isinstance(infered, Instance): + try: + infered._proxied.getattr('__get__', context) + except NotFoundError: + yield infered + else: + yield YES else: - yield YES + yield function_to_method(infered, self) + except NotFoundError: + if not name.startswith('__') and self.has_dynamic_getattr(context): + # class handle some dynamic attributes, return a YES object + yield YES else: - yield function_to_method(infered, self) - except NotFoundError: - if not name.startswith('__') and self.has_dynamic_getattr(context): - # class handle some dynamic attributes, return a YES object - yield YES - else: - raise InferenceError(name) + raise InferenceError(name) def has_dynamic_getattr(self, context=None): """return True if the class has a custom __getattr__ or diff --git a/test/unittest_inference.py b/test/unittest_inference.py index c417e1aa..db950e38 100644 --- a/test/unittest_inference.py +++ b/test/unittest_inference.py @@ -912,7 +912,7 @@ def f(g = lambda: None): g().x ''' astroid = builder.string_build(code, __name__, __file__) - callfuncnode = astroid['f'].body[0].value.expr + callfuncnode = astroid['f'].body[0].value.expr # 'g()' infered = list(callfuncnode.infer()) self.assertEqual(len(infered), 2, infered) infered.remove(YES) diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py index ee035e11..b5245caa 100644 --- a/test/unittest_nodes.py +++ b/test/unittest_nodes.py @@ -308,11 +308,11 @@ except PickleError: def test_absolute_import(self): astroid = abuilder.file_build(self.datapath('absimport.py')) ctx = InferenceContext() - ctx.lookupname = 'message' - # will fail if absolute import failed - astroid['message'].infer(ctx).next() - ctx.lookupname = 'email' - m = astroid['email'].infer(ctx).next() + with ctx.scope(lookupname='message'): + # will fail if absolute import failed + astroid['message'].infer(ctx).next() + with ctx.scope(lookupname='email'): + m = astroid['email'].infer(ctx).next() self.assertFalse(m.file.startswith(self.datapath('email.py'))) diff --git a/test/unittest_scoped_nodes.py b/test/unittest_scoped_nodes.py index b6f3434b..de0b395f 100644 --- a/test/unittest_scoped_nodes.py +++ b/test/unittest_scoped_nodes.py @@ -96,7 +96,8 @@ class ModuleNodeTC(TestCase): del sys.path[1] self.assertEqual(len(NONREGR.getattr('enumerate')), 2) # raise ResolveError - self.assertRaises(InferenceError, MODULE.igetattr, 'YOAA') + gen = MODULE.igetattr('YOAA') + self.assertRaises(InferenceError, list, gen) def test_wildard_import_names(self): m = abuilder.file_build(join(DATA, 'all.py'), 'all') |
