summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bases.py130
-rw-r--r--inference.py63
-rw-r--r--node_classes.py5
-rw-r--r--protocols.py19
-rw-r--r--scoped_nodes.py55
-rw-r--r--test/unittest_inference.py2
-rw-r--r--test/unittest_nodes.py10
-rw-r--r--test/unittest_scoped_nodes.py3
8 files changed, 160 insertions, 127 deletions
diff --git a/bases.py b/bases.py
index fc4f572f..bbdd3da3 100644
--- a/bases.py
+++ b/bases.py
@@ -56,6 +56,31 @@ class Proxy(object):
# Inference ##################################################################
+MISSING = object()
+
+
+class InferenceContextPathContext(object):
+ """Implementation of InferenceContext.push.
+
+ Can't be a @contextmanager because it raises StopIteration.
+ """
+ def __init__(self, context, node):
+ self.original_path = context.path.copy()
+ self.context = context
+ self.node = node
+
+ def __enter__(self):
+ name = self.context.lookupname
+ if (self.node, name) in self.context.path:
+ raise StopIteration
+
+ self.context.path.add((self.node, name))
+ return self
+
+ def __exit__(self, *exc_info):
+ self.context.path = self.original_path
+
+
class InferenceContext(object):
__slots__ = ('path', 'lookupname', 'callcontext', 'boundnode')
@@ -69,17 +94,21 @@ class InferenceContext(object):
self.boundnode = None
def push(self, node):
- name = self.lookupname
- if (node, name) in self.path:
- raise StopIteration()
- self.path.add( (node, name) )
-
- def clone(self):
- # XXX copy lookupname/callcontext ?
- clone = InferenceContext(self.path)
- clone.callcontext = self.callcontext
- clone.boundnode = self.boundnode
- return clone
+ return InferenceContextPathContext(self, node)
+
+ @contextmanager
+ def scope(self, lookupname=MISSING, callcontext=MISSING, boundnode=MISSING):
+ try:
+ orig = self.lookupname, self.callcontext, self.boundnode
+ if lookupname is not MISSING:
+ self.lookupname = lookupname
+ if callcontext is not MISSING:
+ self.callcontext = callcontext
+ if boundnode is not MISSING:
+ self.boundnode = boundnode
+ yield
+ finally:
+ self.lookupname, self.callcontext, self.boundnode = orig
@contextmanager
def restore_path(self):
@@ -87,39 +116,30 @@ class InferenceContext(object):
yield
self.path = path
-def copy_context(context):
- if context is not None:
- return context.clone()
- else:
- return InferenceContext()
-
def _infer_stmts(stmts, context, frame=None):
"""return an iterator on statements inferred by each statement in <stmts>
"""
stmt = None
infered = False
- if context is not None:
- name = context.lookupname
- context = context.clone()
- else:
- name = None
+ if context is None:
context = InferenceContext()
+ name = context.lookupname
for stmt in stmts:
if stmt is YES:
yield stmt
infered = True
continue
- context.lookupname = stmt._infer_name(frame, name)
- try:
- for infered in stmt.infer(context):
- yield infered
+ with context.scope(lookupname=stmt._infer_name(frame, name)):
+ try:
+ for infered in stmt.infer(context):
+ yield infered
+ infered = True
+ except UnresolvableName:
+ continue
+ except InferenceError:
+ yield YES
infered = True
- except UnresolvableName:
- continue
- except InferenceError:
- yield YES
- infered = True
if not infered:
raise InferenceError(str(stmt))
@@ -170,20 +190,23 @@ class Instance(Proxy):
def igetattr(self, name, context=None):
"""inferred getattr"""
+ if not context:
+ context = InferenceContext()
try:
# avoid recursively inferring the same attr on the same class
- if context:
- context.push((self._proxied, name))
- # XXX frame should be self._proxied, or not ?
- get_attr = self.getattr(name, context, lookupclass=False)
- return _infer_stmts(self._wrap_attr(get_attr, context), context,
- frame=self)
+ with context.push((self._proxied, name)):
+ # XXX frame should be self._proxied, or not ?
+ get_attr = self.getattr(name, context, lookupclass=False)
+ for infered in _infer_stmts(self._wrap_attr(get_attr, context), context,
+ frame=self):
+ yield infered
except NotFoundError:
try:
# fallback to class'igetattr since it has some logic to handle
# descriptors
- return self._wrap_attr(self._proxied.igetattr(name, context),
- context)
+ for attr in self._wrap_attr(self._proxied.igetattr(name, context),
+ context):
+ yield attr
except NotFoundError:
raise InferenceError(name)
@@ -274,9 +297,9 @@ class BoundMethod(UnboundMethod):
return True
def infer_call_result(self, caller, context):
- context = context.clone()
- context.boundnode = self.bound
- return self._proxied.infer_call_result(caller, context)
+ with context.scope(boundnode=self.bound, lookupname=None):
+ for infered in self._proxied.infer_call_result(caller, context):
+ yield infered
class Generator(Instance):
@@ -308,17 +331,17 @@ def path_wrapper(func):
"""wrapper function handling context"""
if context is None:
context = InferenceContext()
- context.push(node)
- yielded = set()
- for res in _func(node, context, **kwargs):
- # unproxy only true instance, not const, tuple, dict...
- if res.__class__ is Instance:
- ares = res._proxied
- else:
- ares = res
- if not ares in yielded:
- yield res
- yielded.add(ares)
+ with context.push(node):
+ yielded = set()
+ for res in _func(node, context, **kwargs):
+ # unproxy only true instance, not const, tuple, dict...
+ if res.__class__ is Instance:
+ ares = res._proxied
+ else:
+ ares = res
+ if not ares in yielded:
+ yield res
+ yielded.add(ares)
return wrapped
def yes_if_nothing_infered(func):
@@ -377,6 +400,7 @@ class NodeNG(object):
return self._explicit_inference(self, context, **kwargs)
except UseInferenceDefault:
pass
+
return self._infer(context, **kwargs)
def _repr_name(self):
diff --git a/inference.py b/inference.py
index 35cce332..7bc2313b 100644
--- a/inference.py
+++ b/inference.py
@@ -28,7 +28,7 @@ from astroid.manager import AstroidManager
from astroid.exceptions import (AstroidError,
InferenceError, NoDefault, NotFoundError, UnresolvableName)
from astroid.bases import YES, Instance, InferenceContext, \
- _infer_stmts, copy_context, path_wrapper, raise_if_nothing_infered
+ _infer_stmts, path_wrapper, raise_if_nothing_infered
from astroid.protocols import _arguments_infer_argname
MANAGER = AstroidManager()
@@ -146,29 +146,33 @@ def infer_name(self, context=None):
frame, stmts = self.lookup(self.name)
if not stmts:
raise UnresolvableName(self.name)
- context = context.clone()
- context.lookupname = self.name
- return _infer_stmts(stmts, context, frame)
+ with context.scope(lookupname=self.name):
+ for infered in _infer_stmts(stmts, context, frame):
+ yield infered
nodes.Name._infer = path_wrapper(infer_name)
nodes.AssName.infer_lhs = infer_name # won't work with a path wrapper
def infer_callfunc(self, context=None):
"""infer a CallFunc node by trying to guess what the function returns"""
- callcontext = context.clone()
- callcontext.callcontext = CallContext(self.args, self.starargs, self.kwargs)
- callcontext.boundnode = None
+ if context is None:
+ context = InferenceContext()
for callee in self.func.infer(context):
- if callee is YES:
- yield callee
- continue
- try:
- if hasattr(callee, 'infer_call_result'):
- for infered in callee.infer_call_result(self, callcontext):
- yield infered
- except InferenceError:
- ## XXX log error ?
- continue
+ with context.scope(
+ callcontext=CallContext(self.args, self.starargs, self.kwargs),
+ boundnode=None,
+ lookupname=None,
+ ):
+ if callee is YES:
+ yield callee
+ continue
+ try:
+ if hasattr(callee, 'infer_call_result'):
+ for infered in callee.infer_call_result(self, context):
+ yield infered
+ except InferenceError:
+ ## XXX log error ?
+ continue
nodes.CallFunc._infer = path_wrapper(raise_if_nothing_infered(infer_callfunc))
@@ -185,8 +189,9 @@ nodes.Import._infer = path_wrapper(infer_import)
def infer_name_module(self, name):
context = InferenceContext()
- context.lookupname = name
- return self.infer(context, asname=False)
+ with context.scope(lookupname=name):
+ for infered in self.infer(context, asname=False):
+ yield infered
nodes.Import.infer_name_module = infer_name_module
@@ -199,9 +204,9 @@ def infer_from(self, context=None, asname=True):
name = self.real_name(name)
module = self.do_import_module(self.modname)
try:
- context = copy_context(context)
- context.lookupname = name
- return _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context)
+ with context.scope(lookupname=name):
+ for infered in _infer_stmts(module.getattr(name, ignore_locals=module is self.root()), context):
+ yield infered
except NotFoundError:
raise InferenceError(name)
nodes.From._infer = path_wrapper(infer_from)
@@ -209,21 +214,21 @@ nodes.From._infer = path_wrapper(infer_from)
def infer_getattr(self, context=None):
"""infer a Getattr node by using getattr on the associated object"""
- #context = context.clone()
+ if not context:
+ context = InferenceContext()
for owner in self.expr.infer(context):
if owner is YES:
yield owner
continue
try:
- context.boundnode = owner
- for obj in owner.igetattr(self.attrname, context):
- yield obj
- context.boundnode = None
+ with context.scope(boundnode=owner, lookupname=None):
+ for obj in owner.igetattr(self.attrname, context):
+ yield obj
except (NotFoundError, InferenceError):
- context.boundnode = None
+ pass
except AttributeError:
# XXX method / function
- context.boundnode = None
+ pass
nodes.Getattr._infer = path_wrapper(raise_if_nothing_infered(infer_getattr))
nodes.AssAttr.infer_lhs = raise_if_nothing_infered(infer_getattr) # # won't work with a path wrapper
diff --git a/node_classes.py b/node_classes.py
index 01dc8d92..e247392e 100644
--- a/node_classes.py
+++ b/node_classes.py
@@ -125,8 +125,7 @@ class LookupMixIn(object):
the lookup method
"""
frame, stmts = self.lookup(name)
- context = InferenceContext()
- return _infer_stmts(stmts, context, frame)
+ return _infer_stmts(stmts, None, frame)
def _filter_stmts(self, stmts, frame, offset):
"""filter statements to remove ignorable statements.
@@ -889,7 +888,7 @@ class Yield(NodeNG):
value = None
class YieldFrom(Yield):
- """ Class representing a YieldFrom node. """
+ """ Class representing a YieldFrom node. """
# constants ##############################################################
diff --git a/protocols.py b/protocols.py
index e66b802c..616340c9 100644
--- a/protocols.py
+++ b/protocols.py
@@ -23,7 +23,7 @@ __doctype__ = "restructuredtext en"
from astroid.exceptions import InferenceError, NoDefault
from astroid.node_classes import unpack_infer
-from astroid.bases import copy_context, \
+from astroid.bases import InferenceContext, \
raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES
from astroid.nodes import const_factory
from astroid import nodes
@@ -239,9 +239,11 @@ def _arguments_infer_argname(self, name, context):
# if there is a default value, yield it. And then yield YES to reflect
# we can't guess given argument value
try:
- context = copy_context(context)
- for infered in self.default_value(name).infer(context):
- yield infered
+ if context is None:
+ context = InferenceContext()
+ with context.scope(lookupname=None):
+ for infered in self.default_value(name).infer(context):
+ yield infered
yield YES
except NoDefault:
yield YES
@@ -251,11 +253,10 @@ def arguments_assigned_stmts(self, node, context, asspath=None):
if context.callcontext:
# reset call context/name
callcontext = context.callcontext
- context = copy_context(context)
- context.callcontext = None
- for infered in callcontext.infer_argument(self.parent, node.name, context):
- yield infered
- return
+ with context.scope(callcontext=None, lookupname=None):
+ for infered in callcontext.infer_argument(self.parent, node.name, context):
+ yield infered
+ return
for infered in _arguments_infer_argname(self, node.name, context):
yield infered
nodes.Arguments.assigned_stmts = arguments_assigned_stmts
diff --git a/scoped_nodes.py b/scoped_nodes.py
index e354c5b8..a5bc37e9 100644
--- a/scoped_nodes.py
+++ b/scoped_nodes.py
@@ -39,7 +39,7 @@ from astroid.node_classes import Const, DelName, DelAttr, \
Dict, From, List, Pass, Raise, Return, Tuple, Yield, YieldFrom, \
LookupMixIn, const_factory as cf, unpack_infer, Name
from astroid.bases import NodeNG, InferenceContext, Instance,\
- YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, copy_context, \
+ YES, Generator, UnboundMethod, BoundMethod, _infer_stmts, \
BUILTINS
from astroid.mixins import FilterStmtsMixin
from astroid.bases import Statement
@@ -308,12 +308,14 @@ class Module(LocalsDictNodeNG):
"""inferred getattr"""
# set lookup name since this is necessary to infer on import nodes for
# instance
- context = copy_context(context)
- context.lookupname = name
- try:
- return _infer_stmts(self.getattr(name, context), context, frame=self)
- except NotFoundError:
- raise InferenceError(name)
+ if not context:
+ context = InferenceContext()
+ with context.scope(lookupname=name):
+ try:
+ for infered in _infer_stmts(self.getattr(name, context), context, frame=self):
+ yield infered
+ except NotFoundError:
+ raise InferenceError(name)
def fully_defined(self):
"""return True if this module has been built from a .py file
@@ -993,27 +995,28 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
"""
# set lookup name since this is necessary to infer on import nodes for
# instance
- context = copy_context(context)
- context.lookupname = name
- try:
- for infered in _infer_stmts(self.getattr(name, context), context,
- frame=self):
- # yield YES object instead of descriptors when necessary
- if not isinstance(infered, Const) and isinstance(infered, Instance):
- try:
- infered._proxied.getattr('__get__', context)
- except NotFoundError:
- yield infered
+ if not context:
+ context = InferenceContext()
+ with context.scope(lookupname=name):
+ try:
+ for infered in _infer_stmts(self.getattr(name, context), context,
+ frame=self):
+ # yield YES object instead of descriptors when necessary
+ if not isinstance(infered, Const) and isinstance(infered, Instance):
+ try:
+ infered._proxied.getattr('__get__', context)
+ except NotFoundError:
+ yield infered
+ else:
+ yield YES
else:
- yield YES
+ yield function_to_method(infered, self)
+ except NotFoundError:
+ if not name.startswith('__') and self.has_dynamic_getattr(context):
+ # class handle some dynamic attributes, return a YES object
+ yield YES
else:
- yield function_to_method(infered, self)
- except NotFoundError:
- if not name.startswith('__') and self.has_dynamic_getattr(context):
- # class handle some dynamic attributes, return a YES object
- yield YES
- else:
- raise InferenceError(name)
+ raise InferenceError(name)
def has_dynamic_getattr(self, context=None):
"""return True if the class has a custom __getattr__ or
diff --git a/test/unittest_inference.py b/test/unittest_inference.py
index c417e1aa..db950e38 100644
--- a/test/unittest_inference.py
+++ b/test/unittest_inference.py
@@ -912,7 +912,7 @@ def f(g = lambda: None):
g().x
'''
astroid = builder.string_build(code, __name__, __file__)
- callfuncnode = astroid['f'].body[0].value.expr
+ callfuncnode = astroid['f'].body[0].value.expr # 'g()'
infered = list(callfuncnode.infer())
self.assertEqual(len(infered), 2, infered)
infered.remove(YES)
diff --git a/test/unittest_nodes.py b/test/unittest_nodes.py
index ee035e11..b5245caa 100644
--- a/test/unittest_nodes.py
+++ b/test/unittest_nodes.py
@@ -308,11 +308,11 @@ except PickleError:
def test_absolute_import(self):
astroid = abuilder.file_build(self.datapath('absimport.py'))
ctx = InferenceContext()
- ctx.lookupname = 'message'
- # will fail if absolute import failed
- astroid['message'].infer(ctx).next()
- ctx.lookupname = 'email'
- m = astroid['email'].infer(ctx).next()
+ with ctx.scope(lookupname='message'):
+ # will fail if absolute import failed
+ astroid['message'].infer(ctx).next()
+ with ctx.scope(lookupname='email'):
+ m = astroid['email'].infer(ctx).next()
self.assertFalse(m.file.startswith(self.datapath('email.py')))
diff --git a/test/unittest_scoped_nodes.py b/test/unittest_scoped_nodes.py
index b6f3434b..de0b395f 100644
--- a/test/unittest_scoped_nodes.py
+++ b/test/unittest_scoped_nodes.py
@@ -96,7 +96,8 @@ class ModuleNodeTC(TestCase):
del sys.path[1]
self.assertEqual(len(NONREGR.getattr('enumerate')), 2)
# raise ResolveError
- self.assertRaises(InferenceError, MODULE.igetattr, 'YOAA')
+ gen = MODULE.igetattr('YOAA')
+ self.assertRaises(InferenceError, list, gen)
def test_wildard_import_names(self):
m = abuilder.file_build(join(DATA, 'all.py'), 'all')