summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2015-05-09 15:58:31 +0300
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2015-05-09 15:58:31 +0300
commitf97ef6e9e89ed2de13b622c1580ef9d14e3fca1f (patch)
tree3217c294f74d007a33029934a85505289c16c663
parent28d8eec1933525151b5468a1bf513f54fa5c50ff (diff)
downloadastroid-f97ef6e9e89ed2de13b622c1580ef9d14e3fca1f.tar.gz
Add basic support for understanding context managers.
Currently, there's no way to understand whatever __enter__ returns in a context manager and what it is binded using the ``as`` keyword. With these changes, we can understand ``bar`` in ``with foo() as bar``, which will be the result of __enter__. There's no support for contextlib.contextmanager yet.
-rw-r--r--ChangeLog7
-rw-r--r--astroid/protocols.py65
-rw-r--r--astroid/tests/unittest_inference.py88
3 files changed, 151 insertions, 9 deletions
diff --git a/ChangeLog b/ChangeLog
index 6cda6c0..afe6f5a 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -92,6 +92,13 @@ Change log for the astroid package (used to be astng)
that all of them will have __getattr__ and __getattribute__ present and it
is wrong to consider that those methods were actually implemented.
+ * Add basic support for understanding context managers.
+
+ Currently, there's no way to understand whatever __enter__ returns in a
+ context manager and what it is binded using the ``as`` keyword. With these changes,
+ we can understand ``bar`` in ``with foo() as bar``, which will be the result of __enter__.
+ There's no support for contextlib.contextmanager yet.
+
2015-03-14 -- 1.3.6
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 4c11f9c..5c9b200 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -24,8 +24,12 @@ import collections
from astroid.exceptions import InferenceError, NoDefault, NotFoundError
from astroid.node_classes import unpack_infer
-from astroid.bases import InferenceContext, copy_context, \
- raise_if_nothing_infered, yes_if_nothing_infered, Instance, YES
+from astroid.bases import (
+ InferenceContext, copy_context,
+ raise_if_nothing_infered, yes_if_nothing_infered,
+ Instance, YES, BoundMethod,
+ Generator,
+)
from astroid.nodes import const_factory
from astroid import nodes
@@ -347,17 +351,60 @@ def excepthandler_assigned_stmts(self, node, context=None, asspath=None):
assigned = Instance(assigned)
yield assigned
nodes.ExceptHandler.assigned_stmts = raise_if_nothing_infered(excepthandler_assigned_stmts)
+
+def _infer_context_manager(self, mgr, context):
+ try:
+ inferred = next(mgr.infer(context=context))
+ except InferenceError:
+ return
+ if isinstance(inferred, Generator):
+ # TODO(cpopa): unsupported for now.
+ return
+ elif isinstance(inferred, Instance):
+ try:
+ enter = next(inferred.igetattr('__enter__', context=context))
+ except (InferenceError, NotFoundError):
+ return
+ if not isinstance(enter, BoundMethod):
+ return
+ for result in enter.infer_call_result(self, context):
+ yield result
def with_assigned_stmts(self, node, context=None, asspath=None):
+ """Infer names and other nodes from a *with* statement.
+
+ This enables only inference for name binding in a *with* statement.
+ For instance, in the following code, inferring `func` will return
+ the `ContextManager` class, not whatever ``__enter__`` returns.
+ We are doing this intentionally, because we consider that the context
+ manager result is whatever __enter__ returns and what it is binded
+ using the ``as`` keyword.
+
+ class ContextManager(object):
+ def __enter__(self):
+ return 42
+ with ContextManager() as f:
+ pass
+ # ContextManager().infer() will return ContextManager
+ # f.infer() will return 42.
+ """
+
+ mgr = next(mgr for (mgr, vars) in self.items if vars == node)
if asspath is None:
- for _, vars in self.items:
- if vars is None:
- continue
- for lst in vars.infer(context):
- if isinstance(lst, (nodes.Tuple, nodes.List)):
- for item in lst.nodes:
- yield item
+ for result in _infer_context_manager(self, mgr, context):
+ yield result
+ else:
+ for result in _infer_context_manager(self, mgr, context):
+ # Walk the asspath and get the item at the final index.
+ obj = result
+ for index in asspath:
+ if not hasattr(obj, 'elts'):
+ raise InferenceError
+ obj = obj.elts[index]
+ yield obj
+
+
nodes.With.assigned_stmts = raise_if_nothing_infered(with_assigned_stmts)
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 3b2bad1..830eb90 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -1655,6 +1655,94 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self.assertIsInstance(inferred, nodes.Class)
self.assertEqual(inferred.qname(), 'collections.Counter')
+ def test_inferring_with_statement_failures(self):
+ module = test_utils.build_module('''
+ class NoEnter(object):
+ pass
+ class NoMethod(object):
+ __enter__ = None
+ class NoElts(object):
+ def __enter__(self):
+ return 42
+
+ with NoEnter() as no_enter:
+ pass
+ with NoMethod() as no_method:
+ pass
+ with NoElts() as (no_elts, no_elts1):
+ pass
+ ''')
+ self.assertRaises(InferenceError, next, module['no_enter'].infer())
+ self.assertRaises(InferenceError, next, module['no_method'].infer())
+ self.assertRaises(InferenceError, next, module['no_elts'].infer())
+
+ def test_inferring_with_statement(self):
+ module = test_utils.build_module('''
+ class SelfContext(object):
+ def __enter__(self):
+ return self
+
+ class OtherContext(object):
+ def __enter__(self):
+ return SelfContext()
+
+ class MultipleReturns(object):
+ def __enter__(self):
+ return SelfContext(), OtherContext()
+
+ class MultipleReturns2(object):
+ def __enter__(self):
+ return [1, [2, 3]]
+
+ with SelfContext() as self_context:
+ pass
+ with OtherContext() as other_context:
+ pass
+ with MultipleReturns(), OtherContext() as multiple_with:
+ pass
+ with MultipleReturns2() as (stdout, (stderr, stdin)):
+ pass
+ ''')
+ self_context = module['self_context']
+ inferred = next(self_context.infer())
+ self.assertIsInstance(inferred, Instance)
+ self.assertEqual(inferred.name, 'SelfContext')
+
+ other_context = module['other_context']
+ inferred = next(other_context.infer())
+ self.assertIsInstance(inferred, Instance)
+ self.assertEqual(inferred.name, 'SelfContext')
+
+ multiple_with = module['multiple_with']
+ inferred = next(multiple_with.infer())
+ self.assertIsInstance(inferred, Instance)
+ self.assertEqual(inferred.name, 'SelfContext')
+
+ stdout = module['stdout']
+ inferred = next(stdout.infer())
+ self.assertIsInstance(inferred, nodes.Const)
+ self.assertEqual(inferred.value, 1)
+ stderr = module['stderr']
+ inferred = next(stderr.infer())
+ self.assertIsInstance(inferred, nodes.Const)
+ self.assertEqual(inferred.value, 2)
+
+ @unittest.expectedFailure
+ def test_inferring_with_contextlib_contextmanager(self):
+ module = test_utils.build_module('''
+ from contextlib import contextmanager
+
+ @contextlib.contextmanager
+ def manager():
+ yield
+
+ with manager() as none: #@
+ pass
+ ''')
+ # TODO(cpopa): no support for contextlib.contextmanager yet.
+ none = module['none']
+ next(none.infer())
+
if __name__ == '__main__':
unittest.main()