diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-05-09 15:58:31 +0300 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-05-09 15:58:31 +0300 |
commit | f97ef6e9e89ed2de13b622c1580ef9d14e3fca1f (patch) | |
tree | 3217c294f74d007a33029934a85505289c16c663 | |
parent | 28d8eec1933525151b5468a1bf513f54fa5c50ff (diff) | |
download | astroid-f97ef6e9e89ed2de13b622c1580ef9d14e3fca1f.tar.gz |
Add basic support for understanding context managers.
Currently, there's no way to understand whatever __enter__ returns in a
context manager and what it is binded using the ``as`` keyword. With these changes,
we can understand ``bar`` in ``with foo() as bar``, which will be the result of __enter__.
There's no support for contextlib.contextmanager yet.
-rw-r--r-- | ChangeLog | 7 | ||||
-rw-r--r-- | astroid/protocols.py | 65 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 88 |
3 files changed, 151 insertions, 9 deletions
@@ -92,6 +92,13 @@ Change log for the astroid package (used to be astng) that all of them will have __getattr__ and __getattribute__ present and it is wrong to consider that those methods were actually implemented. + * Add basic support for understanding context managers. + + Currently, there's no way to understand whatever __enter__ returns in a + context manager and what it is binded using the ``as`` keyword. With these changes, + we can understand ``bar`` in ``with foo() as bar``, which will be the result of __enter__. + There's no support for contextlib.contextmanager yet. + 2015-03-14 -- 1.3.6 diff --git a/astroid/protocols.py b/astroid/protocols.py index 4c11f9c..5c9b200 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -24,8 +24,12 @@ import collections from astroid.exceptions import InferenceError, NoDefault, NotFoundError from astroid.node_classes import unpack_infer -from astroid.bases import InferenceContext, copy_context, \ - raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES +from astroid.bases import ( + InferenceContext, copy_context, + raise_if_nothing_infered, yes_if_nothing_infered, + Instance, YES, BoundMethod, + Generator, +) from astroid.nodes import const_factory from astroid import nodes @@ -347,17 +351,60 @@ def excepthandler_assigned_stmts(self, node, context=None, asspath=None): assigned = Instance(assigned) yield assigned nodes.ExceptHandler.assigned_stmts = raise_if_nothing_infered(excepthandler_assigned_stmts) + +def _infer_context_manager(self, mgr, context): + try: + inferred = next(mgr.infer(context=context)) + except InferenceError: + return + if isinstance(inferred, Generator): + # TODO(cpopa): unsupported for now. + return + elif isinstance(inferred, Instance): + try: + enter = next(inferred.igetattr('__enter__', context=context)) + except (InferenceError, NotFoundError): + return + if not isinstance(enter, BoundMethod): + return + for result in enter.infer_call_result(self, context): + yield result def with_assigned_stmts(self, node, context=None, asspath=None): + """Infer names and other nodes from a *with* statement. + + This enables only inference for name binding in a *with* statement. + For instance, in the following code, inferring `func` will return + the `ContextManager` class, not whatever ``__enter__`` returns. + We are doing this intentionally, because we consider that the context + manager result is whatever __enter__ returns and what it is binded + using the ``as`` keyword. + + class ContextManager(object): + def __enter__(self): + return 42 + with ContextManager() as f: + pass + # ContextManager().infer() will return ContextManager + # f.infer() will return 42. + """ + + mgr = next(mgr for (mgr, vars) in self.items if vars == node) if asspath is None: - for _, vars in self.items: - if vars is None: - continue - for lst in vars.infer(context): - if isinstance(lst, (nodes.Tuple, nodes.List)): - for item in lst.nodes: - yield item + for result in _infer_context_manager(self, mgr, context): + yield result + else: + for result in _infer_context_manager(self, mgr, context): + # Walk the asspath and get the item at the final index. + obj = result + for index in asspath: + if not hasattr(obj, 'elts'): + raise InferenceError + obj = obj.elts[index] + yield obj + + nodes.With.assigned_stmts = raise_if_nothing_infered(with_assigned_stmts) diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 3b2bad1..830eb90 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -1655,6 +1655,94 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): self.assertIsInstance(inferred, nodes.Class) self.assertEqual(inferred.qname(), 'collections.Counter') + def test_inferring_with_statement_failures(self): + module = test_utils.build_module(''' + class NoEnter(object): + pass + class NoMethod(object): + __enter__ = None + class NoElts(object): + def __enter__(self): + return 42 + + with NoEnter() as no_enter: + pass + with NoMethod() as no_method: + pass + with NoElts() as (no_elts, no_elts1): + pass + ''') + self.assertRaises(InferenceError, next, module['no_enter'].infer()) + self.assertRaises(InferenceError, next, module['no_method'].infer()) + self.assertRaises(InferenceError, next, module['no_elts'].infer()) + + def test_inferring_with_statement(self): + module = test_utils.build_module(''' + class SelfContext(object): + def __enter__(self): + return self + + class OtherContext(object): + def __enter__(self): + return SelfContext() + + class MultipleReturns(object): + def __enter__(self): + return SelfContext(), OtherContext() + + class MultipleReturns2(object): + def __enter__(self): + return [1, [2, 3]] + + with SelfContext() as self_context: + pass + with OtherContext() as other_context: + pass + with MultipleReturns(), OtherContext() as multiple_with: + pass + with MultipleReturns2() as (stdout, (stderr, stdin)): + pass + ''') + self_context = module['self_context'] + inferred = next(self_context.infer()) + self.assertIsInstance(inferred, Instance) + self.assertEqual(inferred.name, 'SelfContext') + + other_context = module['other_context'] + inferred = next(other_context.infer()) + self.assertIsInstance(inferred, Instance) + self.assertEqual(inferred.name, 'SelfContext') + + multiple_with = module['multiple_with'] + inferred = next(multiple_with.infer()) + self.assertIsInstance(inferred, Instance) + self.assertEqual(inferred.name, 'SelfContext') + + stdout = module['stdout'] + inferred = next(stdout.infer()) + self.assertIsInstance(inferred, nodes.Const) + self.assertEqual(inferred.value, 1) + stderr = module['stderr'] + inferred = next(stderr.infer()) + self.assertIsInstance(inferred, nodes.Const) + self.assertEqual(inferred.value, 2) + + @unittest.expectedFailure + def test_inferring_with_contextlib_contextmanager(self): + module = test_utils.build_module(''' + from contextlib import contextmanager + + @contextlib.contextmanager + def manager(): + yield + + with manager() as none: #@ + pass + ''') + # TODO(cpopa): no support for contextlib.contextmanager yet. + none = module['none'] + next(none.infer()) + if __name__ == '__main__': unittest.main() |