summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bases.py109
-rw-r--r--scoped_nodes.py43
2 files changed, 75 insertions, 77 deletions
diff --git a/bases.py b/bases.py
index 131d5b1b..ac3414ba 100644
--- a/bases.py
+++ b/bases.py
@@ -61,41 +61,37 @@ class Proxy(object):
MISSING = object()
-class InferenceContextPathContext(object):
- """Implementation of InferenceContext.push.
-
- Can't be a @contextmanager because it raises StopIteration.
- """
- def __init__(self, context, key):
- self.original_path = context.path.copy()
- self.context = context
- self.key = key
-
- def __enter__(self):
- if self.key in self.context.path:
- raise StopIteration
-
- self.context.path.add(self.key)
- return self
-
- def __exit__(self, *exc_info):
- self.context.path = self.original_path
-
-
class InferenceContext(object):
__slots__ = ('path', 'callcontext', 'boundnode', 'infered')
- def __init__(self, path=None):
+ def __init__(self,
+ path=None, callcontext=None, boundnode=None, infered=None):
if path is None:
- self.path = set()
+ self.path = frozenset()
else:
self.path = path
- self.callcontext = None
- self.boundnode = None
- self.infered = {}
+ self.callcontext = callcontext
+ self.boundnode = boundnode
+ if infered is None:
+ self.infered = {}
+ else:
+ self.infered = infered
def push(self, key):
- return InferenceContextPathContext(self, key)
+ # This returns a NEW context with the same attributes, but a new key
+ # added to `path`. The intention is that it's only passed to callees
+ # and then destroyed; otherwise scope() may not work correctly.
+ # The cache will be shared, since it's the same exact dict.
+ if key in self.path:
+ # End the containing generator
+ raise StopIteration
+
+ return InferenceContext(
+ self.path.union([key]),
+ self.callcontext,
+ self.boundnode,
+ self.infered,
+ )
@contextmanager
def scope(self, callcontext=MISSING, boundnode=MISSING):
@@ -109,11 +105,14 @@ class InferenceContext(object):
finally:
self.callcontext, self.boundnode = orig
- @contextmanager
- def restore_path(self):
- path = set(self.path)
- yield
- self.path = path
+ def cache_generator(self, key, generator):
+ results = []
+ for result in generator:
+ results.append(result)
+ yield result
+
+ self.infered[key] = tuple(results)
+ return
def _infer_stmts(stmts, context, frame=None, lookupname=None):
@@ -198,19 +197,20 @@ class Instance(Proxy):
context = InferenceContext()
try:
# avoid recursively inferring the same attr on the same class
- with context.push((self._proxied, name)):
- # XXX frame should be self._proxied, or not ?
- get_attr = self.getattr(name, context, lookupclass=False)
- for infered in _infer_stmts(self._wrap_attr(get_attr, context), context,
- frame=self):
- yield infered
+ new_context = context.push((self._proxied, name))
+ # XXX frame should be self._proxied, or not ?
+ get_attr = self.getattr(name, new_context, lookupclass=False)
+ return _infer_stmts(
+ self._wrap_attr(get_attr, new_context),
+ new_context,
+ frame=self,
+ )
except NotFoundError:
try:
# fallback to class'igetattr since it has some logic to handle
# descriptors
- for attr in self._wrap_attr(self._proxied.igetattr(name, context),
- context):
- yield attr
+ return self._wrap_attr(self._proxied.igetattr(name, context),
+ context)
except NotFoundError:
raise InferenceError(name)
@@ -335,17 +335,18 @@ def path_wrapper(func):
"""wrapper function handling context"""
if context is None:
context = InferenceContext()
- with context.push((node, kwargs.get('lookupname'))):
- yielded = set()
- for res in _func(node, context, **kwargs):
- # unproxy only true instance, not const, tuple, dict...
- if res.__class__ is Instance:
- ares = res._proxied
- else:
- ares = res
- if not ares in yielded:
- yield res
- yielded.add(ares)
+ context = context.push((node, kwargs.get('lookupname')))
+
+ yielded = set()
+ for res in _func(node, context, **kwargs):
+ # unproxy only true instance, not const, tuple, dict...
+ if res.__class__ is Instance:
+ ares = res._proxied
+ else:
+ ares = res
+ if not ares in yielded:
+ yield res
+ yielded.add(ares)
return wrapped
def yes_if_nothing_infered(func):
@@ -412,9 +413,7 @@ class NodeNG(object):
if key in context.infered:
return iter(context.infered[key])
- eager = tuple(self._infer(context, **kwargs))
- context.infered[key] = eager
- return iter(eager)
+ return context.cache_generator(key, self._infer(context, **kwargs))
def _repr_name(self):
"""return self.name or self.attrname or '' for nice representation"""
diff --git a/scoped_nodes.py b/scoped_nodes.py
index 60185717..7e6349df 100644
--- a/scoped_nodes.py
+++ b/scoped_nodes.py
@@ -886,28 +886,27 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin):
if context is None:
context = InferenceContext()
for stmt in self.bases:
- with context.restore_path():
- try:
- for baseobj in stmt.infer(context):
- if not isinstance(baseobj, Class):
- if isinstance(baseobj, Instance):
- baseobj = baseobj._proxied
- else:
- # duh ?
- continue
- if baseobj in yielded:
- continue # cf xxx above
- yielded.add(baseobj)
- yield baseobj
- if recurs:
- for grandpa in baseobj.ancestors(True, context):
- if grandpa in yielded:
- continue # cf xxx above
- yielded.add(grandpa)
- yield grandpa
- except InferenceError:
- # XXX log error ?
- continue
+ try:
+ for baseobj in stmt.infer(context):
+ if not isinstance(baseobj, Class):
+ if isinstance(baseobj, Instance):
+ baseobj = baseobj._proxied
+ else:
+ # duh ?
+ continue
+ if baseobj in yielded:
+ continue # cf xxx above
+ yielded.add(baseobj)
+ yield baseobj
+ if recurs:
+ for grandpa in baseobj.ancestors(True, context):
+ if grandpa in yielded:
+ continue # cf xxx above
+ yielded.add(grandpa)
+ yield grandpa
+ except InferenceError:
+ # XXX log error ?
+ continue
def local_attr_ancestors(self, name, context=None):
"""return an iterator on astroid representation of parent classes