summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2014-12-03 12:57:44 +0200
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2014-12-03 12:57:44 +0200
commitfd0784a6c31fd9e129c40081e244205e7586ffa5 (patch)
tree6e909d1c93d1ca2c59383639b2454417a71994a6
parent520f16326bb09dd3a91c52abdd502dd50ae230c2 (diff)
downloadastroid-fd0784a6c31fd9e129c40081e244205e7586ffa5.tar.gz
Add inference tips for dict calls.
This patch properly infers dict constructor calls, as in dict(), dict(kwarg=value), dict(<iterable>), dict(<mapping>) and dict(<iterable> or <mapping>, **kwarg) syntax.
-rw-r--r--ChangeLog2
-rw-r--r--astroid/brain/builtin_inference.py69
-rw-r--r--astroid/tests/unittest_inference.py52
3 files changed, 121 insertions, 2 deletions
diff --git a/ChangeLog b/ChangeLog
index a885359..a4764be 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -9,7 +9,7 @@ Change log for the astroid package (used to be astng)
all the gathered file streams. This is consistent with how
the API worked pre-astroid 1.3 (meaning that no seek is required).
- * Add inference tips for 'tuple', 'list' and 'set' builtins.
+ * Add inference tips for 'tuple', 'list', 'dict' and 'set' builtins.
2014-11-22 -- 1.3.2
diff --git a/astroid/brain/builtin_inference.py b/astroid/brain/builtin_inference.py
index 0aea53d..1173e55 100644
--- a/astroid/brain/builtin_inference.py
+++ b/astroid/brain/builtin_inference.py
@@ -3,7 +3,7 @@ from functools import partial
import six
from astroid import (MANAGER, UseInferenceDefault,
- inference_tip, YES, InferenceError)
+ inference_tip, YES, InferenceError, UnresolvableName)
from astroid import nodes
@@ -102,7 +102,74 @@ infer_set = partial(
iterables=(nodes.List, nodes.Tuple),
build_elts=set)
+
+def _get_elts(arg, context):
+ is_iterable = lambda n: isinstance(n,
+ (nodes.List, nodes.Tuple, nodes.Set))
+ try:
+ infered = next(arg.infer(context))
+ except (InferenceError, UnresolvableName):
+ raise UseInferenceDefault()
+ if isinstance(infered, nodes.Dict):
+ items = infered.items
+ elif is_iterable(infered):
+ items = []
+ for elt in infered.elts:
+ # If an item is not a pair of two items,
+ # then fallback to the default inference.
+ # Also, take in consideration only hashable items,
+ # tuples and consts. We are choosing Names as well.
+ if not is_iterable(elt):
+ raise UseInferenceDefault()
+ if len(elt.elts) != 2:
+ raise UseInferenceDefault()
+ if not isinstance(elt.elts[0],
+ (nodes.Tuple, nodes.Const, nodes.Name)):
+ raise UseInferenceDefault()
+ items.append(tuple(elt.elts))
+ else:
+ raise UseInferenceDefault()
+ return items
+
+def infer_dict(node, context=None):
+ """Try to infer a dict call to a Dict node.
+
+ The function treats the following cases:
+
+ * dict()
+ * dict(mapping)
+ * dict(iterable)
+ * dict(iterable, **kwargs)
+ * dict(mapping, **kwargs)
+ * dict(**kwargs)
+
+ If a case can't be infered, we'll fallback to default inference.
+ """
+ has_keywords = lambda args: all(isinstance(arg, nodes.Keyword)
+ for arg in args)
+ if not node.args and not node.kwargs:
+ # dict()
+ return nodes.Dict()
+ elif has_keywords(node.args) and node.args:
+ # dict(a=1, b=2, c=4)
+ items = [(nodes.Const(arg.arg), arg.value) for arg in node.args]
+ elif (len(node.args) >= 2 and
+ has_keywords(node.args[1:])):
+ # dict(some_iterable, b=2, c=4)
+ elts = _get_elts(node.args[0], context)
+ keys = [(nodes.Const(arg.arg), arg.value) for arg in node.args[1:]]
+ items = elts + keys
+ elif len(node.args) == 1:
+ items = _get_elts(node.args[0], context)
+ else:
+ raise UseInferenceDefault()
+
+ empty = nodes.Dict()
+ empty.items = items
+ return empty
+
# Builtins inference
register_builtin_transform(infer_tuple, 'tuple')
register_builtin_transform(infer_set, 'set')
register_builtin_transform(infer_list, 'list')
+register_builtin_transform(infer_dict, 'dict')
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index a097c1a..f1463b1 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -1388,6 +1388,14 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self.assertEqual(sorted(elt.value for elt in infered.elts),
elts)
+ def _test_dict_inference(self, node, expected):
+ infered = next(node.infer())
+ self.assertIsInstance(infered, nodes.Dict)
+
+ elts = set([(key.value, value.value)
+ for (key, value) in infered.items])
+ self.assertEqual(sorted(elts), sorted(expected.items()))
+
def test_tuple_builtin_inference(self):
code = """
var = (1, 2)
@@ -1495,6 +1503,50 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self._test_builtin_inference(nodes.Tuple, astroid[1], [97, 98, 99])
self._test_builtin_inference(nodes.Set, astroid[2], [97, 98, 99])
+ def test_dict_inference(self):
+ code = """
+ dict() #@
+ dict(a=1, b=2, c=3) #@
+ dict([(1, 2), (2, 3)]) #@
+ dict([[1, 2], [2, 3]]) #@
+ dict([(1, 2), [2, 3]]) #@
+ dict([('a', 2)], b=2, c=3) #@
+ dict({1: 2}) #@
+ dict({'c': 2}, a=4, b=5) #@
+ def func():
+ return dict(a=1, b=2)
+ func() #@
+ var = {'x': 2, 'y': 3}
+ dict(var, a=1, b=2) #@
+
+ dict([1, 2, 3]) #@
+ dict([(1, 2), (1, 2, 3)]) #@
+ dict({1: 2}, {1: 2}) #@
+ dict({1: 2}, (1, 2)) #@
+ dict({1: 2}, (1, 2), a=4) #@
+ dict([(1, 2), ([4, 5], 2)]) #@
+ dict([None, None]) #@
+
+ def using_unknown_kwargs(**kwargs):
+ return dict(**kwargs)
+ using_unknown_kwargs(a=1, b=2) #@
+ """
+ astroid = test_utils.extract_node(code, __name__)
+ self._test_dict_inference(astroid[0], {})
+ self._test_dict_inference(astroid[1], {'a': 1, 'b': 2, 'c': 3})
+ for i in range(2, 5):
+ self._test_dict_inference(astroid[i], {1: 2, 2: 3})
+ self._test_dict_inference(astroid[5], {'a': 2, 'b': 2, 'c': 3})
+ self._test_dict_inference(astroid[6], {1: 2})
+ self._test_dict_inference(astroid[7], {'c': 2, 'a': 4, 'b': 5})
+ self._test_dict_inference(astroid[8], {'a': 1, 'b': 2})
+ self._test_dict_inference(astroid[9], {'x': 2, 'y': 3, 'a': 1, 'b': 2})
+
+ for node in astroid[10:]:
+ infered = next(node.infer())
+ self.assertIsInstance(infered, Instance)
+ self.assertEqual(infered.qname(), "{}.dict".format(BUILTINS))
+
if __name__ == '__main__':
unittest.main()