diff options
| -rw-r--r-- | bases.py | 109 | ||||
| -rw-r--r-- | scoped_nodes.py | 43 |
2 files changed, 75 insertions, 77 deletions
@@ -61,41 +61,37 @@ class Proxy(object): MISSING = object() -class InferenceContextPathContext(object): - """Implementation of InferenceContext.push. - - Can't be a @contextmanager because it raises StopIteration. - """ - def __init__(self, context, key): - self.original_path = context.path.copy() - self.context = context - self.key = key - - def __enter__(self): - if self.key in self.context.path: - raise StopIteration - - self.context.path.add(self.key) - return self - - def __exit__(self, *exc_info): - self.context.path = self.original_path - - class InferenceContext(object): __slots__ = ('path', 'callcontext', 'boundnode', 'infered') - def __init__(self, path=None): + def __init__(self, + path=None, callcontext=None, boundnode=None, infered=None): if path is None: - self.path = set() + self.path = frozenset() else: self.path = path - self.callcontext = None - self.boundnode = None - self.infered = {} + self.callcontext = callcontext + self.boundnode = boundnode + if infered is None: + self.infered = {} + else: + self.infered = infered def push(self, key): - return InferenceContextPathContext(self, key) + # This returns a NEW context with the same attributes, but a new key + # added to `path`. The intention is that it's only passed to callees + # and then destroyed; otherwise scope() may not work correctly. + # The cache will be shared, since it's the same exact dict. + if key in self.path: + # End the containing generator + raise StopIteration + + return InferenceContext( + self.path.union([key]), + self.callcontext, + self.boundnode, + self.infered, + ) @contextmanager def scope(self, callcontext=MISSING, boundnode=MISSING): @@ -109,11 +105,14 @@ class InferenceContext(object): finally: self.callcontext, self.boundnode = orig - @contextmanager - def restore_path(self): - path = set(self.path) - yield - self.path = path + def cache_generator(self, key, generator): + results = [] + for result in generator: + results.append(result) + yield result + + self.infered[key] = tuple(results) + return def _infer_stmts(stmts, context, frame=None, lookupname=None): @@ -198,19 +197,20 @@ class Instance(Proxy): context = InferenceContext() try: # avoid recursively inferring the same attr on the same class - with context.push((self._proxied, name)): - # XXX frame should be self._proxied, or not ? - get_attr = self.getattr(name, context, lookupclass=False) - for infered in _infer_stmts(self._wrap_attr(get_attr, context), context, - frame=self): - yield infered + new_context = context.push((self._proxied, name)) + # XXX frame should be self._proxied, or not ? + get_attr = self.getattr(name, new_context, lookupclass=False) + return _infer_stmts( + self._wrap_attr(get_attr, new_context), + new_context, + frame=self, + ) except NotFoundError: try: # fallback to class'igetattr since it has some logic to handle # descriptors - for attr in self._wrap_attr(self._proxied.igetattr(name, context), - context): - yield attr + return self._wrap_attr(self._proxied.igetattr(name, context), + context) except NotFoundError: raise InferenceError(name) @@ -335,17 +335,18 @@ def path_wrapper(func): """wrapper function handling context""" if context is None: context = InferenceContext() - with context.push((node, kwargs.get('lookupname'))): - yielded = set() - for res in _func(node, context, **kwargs): - # unproxy only true instance, not const, tuple, dict... - if res.__class__ is Instance: - ares = res._proxied - else: - ares = res - if not ares in yielded: - yield res - yielded.add(ares) + context = context.push((node, kwargs.get('lookupname'))) + + yielded = set() + for res in _func(node, context, **kwargs): + # unproxy only true instance, not const, tuple, dict... + if res.__class__ is Instance: + ares = res._proxied + else: + ares = res + if not ares in yielded: + yield res + yielded.add(ares) return wrapped def yes_if_nothing_infered(func): @@ -412,9 +413,7 @@ class NodeNG(object): if key in context.infered: return iter(context.infered[key]) - eager = tuple(self._infer(context, **kwargs)) - context.infered[key] = eager - return iter(eager) + return context.cache_generator(key, self._infer(context, **kwargs)) def _repr_name(self): """return self.name or self.attrname or '' for nice representation""" diff --git a/scoped_nodes.py b/scoped_nodes.py index 60185717..7e6349df 100644 --- a/scoped_nodes.py +++ b/scoped_nodes.py @@ -886,28 +886,27 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): if context is None: context = InferenceContext() for stmt in self.bases: - with context.restore_path(): - try: - for baseobj in stmt.infer(context): - if not isinstance(baseobj, Class): - if isinstance(baseobj, Instance): - baseobj = baseobj._proxied - else: - # duh ? - continue - if baseobj in yielded: - continue # cf xxx above - yielded.add(baseobj) - yield baseobj - if recurs: - for grandpa in baseobj.ancestors(True, context): - if grandpa in yielded: - continue # cf xxx above - yielded.add(grandpa) - yield grandpa - except InferenceError: - # XXX log error ? - continue + try: + for baseobj in stmt.infer(context): + if not isinstance(baseobj, Class): + if isinstance(baseobj, Instance): + baseobj = baseobj._proxied + else: + # duh ? + continue + if baseobj in yielded: + continue # cf xxx above + yielded.add(baseobj) + yield baseobj + if recurs: + for grandpa in baseobj.ancestors(True, context): + if grandpa in yielded: + continue # cf xxx above + yielded.add(grandpa) + yield grandpa + except InferenceError: + # XXX log error ? + continue def local_attr_ancestors(self, name, context=None): """return an iterator on astroid representation of parent classes |
