summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2014-11-28 18:37:48 +0200
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2014-11-28 18:37:48 +0200
commit3de8d78fa93140639a06d7d4dcc47d8d77990ab5 (patch)
tree1d5ab49dd873df8da50c668ad5147fd0233df097
parent27869a6bd13166b852fb4b8bd40219c3bd286282 (diff)
downloadastroid-3de8d78fa93140639a06d7d4dcc47d8d77990ab5.tar.gz
Add inference tips for 'tuple', 'set' and 'list' builtins.
This patch adds some inference tips for the mentioned builtins, making calls such as 'list([4, 5, 6])` be inferred the same as `[4, 5, 6]`. Previosly, inferring those objects resulted in an Instance of the given builtin, without having access to their elements. This is useful, for instance, when trying to infer slots defined with set, list of tuple calls (instead of their syntactic equivalents).
-rw-r--r--ChangeLog2
-rw-r--r--astroid/brain/builtin_inference.py109
-rw-r--r--astroid/node_classes.py2
-rw-r--r--astroid/tests/unittest_inference.py115
4 files changed, 223 insertions, 5 deletions
diff --git a/ChangeLog b/ChangeLog
index fa4d466..a885359 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -9,6 +9,8 @@ Change log for the astroid package (used to be astng)
all the gathered file streams. This is consistent with how
the API worked pre-astroid 1.3 (meaning that no seek is required).
+ * Add inference tips for 'tuple', 'list' and 'set' builtins.
+
2014-11-22 -- 1.3.2
* Fixed a crash with invalid subscript index.
diff --git a/astroid/brain/builtin_inference.py b/astroid/brain/builtin_inference.py
new file mode 100644
index 0000000..bfff963
--- /dev/null
+++ b/astroid/brain/builtin_inference.py
@@ -0,0 +1,109 @@
+"""Astroid hooks for various builtins."""
+from functools import partial
+
+import six
+from astroid import (MANAGER, UseInferenceDefault,
+ inference_tip, YES, InferenceError)
+from astroid import nodes
+
+
+def register_builtin_transform(transform, builtin_name):
+ """Register a new transform function for the given *builtin_name*.
+
+ The transform function must accept two parameters, a node and
+ an optional context.
+ """
+ def _transform_wrapper(node, context=None):
+ result = transform(node, context=context)
+ if result:
+ result.parent = node
+ result.lineno = node.lineno
+ return iter([result])
+
+ MANAGER.register_transform(nodes.CallFunc,
+ inference_tip(_transform_wrapper),
+ lambda n: (isinstance(n.func, nodes.Name) and
+ n.func.name == builtin_name))
+
+
+def _generic_inference(node, context, node_type, transform):
+ args = node.args
+ if not args:
+ return node_type()
+ if len(node.args) > 1:
+ raise UseInferenceDefault()
+
+ arg, = args
+ transformed = transform(arg)
+ if not transformed:
+ try:
+ infered = next(arg.infer(context=context))
+ except (InferenceError, StopIteration):
+ raise UseInferenceDefault()
+ if infered is YES:
+ raise UseInferenceDefault()
+ transformed = transform(infered)
+ if not transformed or transformed is YES:
+ raise UseInferenceDefault()
+ return transformed
+
+
+def _generic_transform(arg, klass, iterables, build_elts):
+ if isinstance(arg, klass):
+ return arg
+ elif isinstance(arg, iterables):
+ if not all(isinstance(elt, nodes.Const)
+ for elt in arg.elts):
+ # TODO(cpopa): Don't support heterogenous elements.
+ # Not yet, though.
+ raise UseInferenceDefault()
+ elts = [elt.value for elt in arg.elts]
+ elif isinstance(arg, nodes.Dict):
+ if not all(isinstance(elt[0], nodes.Const)
+ for elt in arg.items):
+ raise UseInferenceDefault()
+ elts = [item[0].value for item in arg.items]
+ elif (isinstance(arg, nodes.Const) and
+ isinstance(arg.value, six.text_type)):
+ elts = arg.value
+ else:
+ return
+ return klass(elts=build_elts(elts))
+
+
+def _infer_builtin(node, context,
+ klass=None, iterables=None,
+ build_elts=None):
+ transform_func = partial(
+ _generic_transform,
+ klass=klass,
+ iterables=iterables,
+ build_elts=build_elts)
+
+ return _generic_inference(node, context, klass, transform_func)
+
+# pylint: disable=invalid-name
+infer_tuple = partial(
+ _infer_builtin,
+ klass=nodes.Tuple,
+ iterables=(nodes.List, nodes.Set),
+ build_elts=tuple)
+
+infer_list = partial(
+ _infer_builtin,
+ klass=nodes.List,
+ iterables=(nodes.Tuple, nodes.Set),
+ build_elts=list)
+
+infer_set = partial(
+ _infer_builtin,
+ klass=nodes.Set,
+ iterables=(nodes.List, nodes.Tuple),
+ build_elts=set)
+
+# Builtins inference
+register_builtin_transform(infer_tuple, 'tuple')
+register_builtin_transform(infer_set, 'set')
+register_builtin_transform(infer_list, 'list')
+# Not exactly the same as set, though.
+register_builtin_transform(infer_set, 'frozenset')
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index 71e512f..9607f2a 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -46,7 +46,7 @@ def unpack_infer(stmt, context=None):
if infered is stmt:
yield infered
return
- # else, infer recursivly, except YES object that should be returned as is
+ # else, infer recursivly, except YES object that should be returned as is
for infered in stmt.infer(context):
if infered is YES:
yield infered
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index b136a69..eda8fdd 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -529,8 +529,8 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
name = test_utils.extract_node(code, __name__)
it = name.infer()
tags = next(it)
- self.assertEqual(tags.__class__, Instance)
- self.assertEqual(tags._proxied.name, 'list')
+ self.assertIsInstance(tags, nodes.List)
+ self.assertEqual(tags.elts, [])
with self.assertRaises(StopIteration):
next(it)
@@ -1368,7 +1368,7 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
pass
""")
self.assertIn(
- 'object',
+ 'object',
[base.name for base in klass.ancestors()])
def test_stop_iteration_leak(self):
@@ -1381,7 +1381,114 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
astroid = test_utils.extract_node(code, __name__)
expr = astroid.func.expr
self.assertIs(next(expr.infer()), YES)
-
+
+ def _test_builtin_inference(self, node_type, node, elts):
+ infered = next(node.infer())
+ self.assertIsInstance(infered, node_type)
+ self.assertEqual(sorted(elt.value for elt in infered.elts),
+ elts)
+
+ def test_tuple_builtin_inference(self):
+ code = """
+ var = (1, 2)
+ tuple() #@
+ tuple([1]) #@
+ tuple({2}) #@
+ tuple("abc") #@
+ tuple({1: 2}) #@
+ tuple(var) #@
+ tuple(tuple([1])) #@
+
+ tuple(None) #@
+ tuple(1) #@
+ tuple(1, 2) #@
+ """
+ astroid = test_utils.extract_node(code, __name__)
+ tuple_inference = partial(self._test_builtin_inference, nodes.Tuple)
+
+ tuple_inference(astroid[0], [])
+ tuple_inference(astroid[1], [1])
+ tuple_inference(astroid[2], [2])
+ tuple_inference(astroid[3], ["a", "b", "c"])
+ tuple_inference(astroid[4], [1])
+ tuple_inference(astroid[5], [1, 2])
+ tuple_inference(astroid[6], [1])
+
+ for node in astroid[7:]:
+ infered = next(node.infer())
+ self.assertIsInstance(infered, Instance)
+ self.assertEqual(infered.qname(), "{}.tuple".format(BUILTINS))
+
+ def test_set_builtin_inference(self):
+ code = """
+ var = (1, 2)
+ set() #@
+ set([1, 2, 1]) #@
+ set({2, 3, 1}) #@
+ set("abcab") #@
+ set({1: 2}) #@
+ set(var) #@
+ set(tuple([1])) #@
+ frozenset([1, 2, 1]) #@
+ frozenset({1: 2, 2: 3}) #@
+ frozenset() #@
+
+ set(set(tuple([4, 5, set([2])]))) #@
+ set(None) #@
+ set(1) #@
+ set(1, 2) #@
+ """
+ astroid = test_utils.extract_node(code, __name__)
+ tuple_inference = partial(self._test_builtin_inference, nodes.Set)
+
+ tuple_inference(astroid[0], [])
+ tuple_inference(astroid[1], [1, 2])
+ tuple_inference(astroid[2], [1, 2, 3])
+ tuple_inference(astroid[3], ["a", "b", "c"])
+ tuple_inference(astroid[4], [1])
+ tuple_inference(astroid[5], [1, 2])
+ tuple_inference(astroid[6], [1])
+ tuple_inference(astroid[7], [1, 2])
+ tuple_inference(astroid[8], [1, 2])
+ tuple_inference(astroid[9], [])
+
+ for node in astroid[10:]:
+ infered = next(node.infer())
+ self.assertIsInstance(infered, Instance)
+ self.assertEqual(infered.qname(), "{}.set".format(BUILTINS))
+
+ def test_list_builtin_inference(self):
+ code = """
+ var = (1, 2)
+ list() #@
+ list([1, 2, 1]) #@
+ list({2, 3, 1}) #@
+ list("abcab") #@
+ list({1: 2}) #@
+ list(var) #@
+ list(tuple([1])) #@
+
+ list(list(tuple([4, 5, list([2])]))) #@
+ list(None) #@
+ list(1) #@
+ list(1, 2) #@
+ """
+ astroid = test_utils.extract_node(code, __name__)
+ tuple_inference = partial(self._test_builtin_inference, nodes.List)
+
+ tuple_inference(astroid[0], [])
+ tuple_inference(astroid[1], [1, 1, 2])
+ tuple_inference(astroid[2], [1, 2, 3])
+ tuple_inference(astroid[3], ["a", "a", "b", "b", "c"])
+ tuple_inference(astroid[4], [1])
+ tuple_inference(astroid[5], [1, 2])
+ tuple_inference(astroid[6], [1])
+
+ for node in astroid[7:]:
+ infered = next(node.infer())
+ self.assertIsInstance(infered, Instance)
+ self.assertEqual(infered.qname(), "{}.list".format(BUILTINS))
+
if __name__ == '__main__':
unittest.main()