diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2014-11-28 18:37:48 +0200 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2014-11-28 18:37:48 +0200 |
commit | 3de8d78fa93140639a06d7d4dcc47d8d77990ab5 (patch) | |
tree | 1d5ab49dd873df8da50c668ad5147fd0233df097 | |
parent | 27869a6bd13166b852fb4b8bd40219c3bd286282 (diff) | |
download | astroid-3de8d78fa93140639a06d7d4dcc47d8d77990ab5.tar.gz |
Add inference tips for 'tuple', 'set' and 'list' builtins.
This patch adds some inference tips for the mentioned builtins,
making calls such as 'list([4, 5, 6])` be inferred the same as `[4, 5, 6]`.
Previosly, inferring those objects resulted in an Instance of the given builtin,
without having access to their elements. This is useful, for instance, when trying
to infer slots defined with set, list of tuple calls (instead of their syntactic equivalents).
-rw-r--r-- | ChangeLog | 2 | ||||
-rw-r--r-- | astroid/brain/builtin_inference.py | 109 | ||||
-rw-r--r-- | astroid/node_classes.py | 2 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 115 |
4 files changed, 223 insertions, 5 deletions
@@ -9,6 +9,8 @@ Change log for the astroid package (used to be astng) all the gathered file streams. This is consistent with how the API worked pre-astroid 1.3 (meaning that no seek is required). + * Add inference tips for 'tuple', 'list' and 'set' builtins. + 2014-11-22 -- 1.3.2 * Fixed a crash with invalid subscript index. diff --git a/astroid/brain/builtin_inference.py b/astroid/brain/builtin_inference.py new file mode 100644 index 0000000..bfff963 --- /dev/null +++ b/astroid/brain/builtin_inference.py @@ -0,0 +1,109 @@ +"""Astroid hooks for various builtins.""" +from functools import partial + +import six +from astroid import (MANAGER, UseInferenceDefault, + inference_tip, YES, InferenceError) +from astroid import nodes + + +def register_builtin_transform(transform, builtin_name): + """Register a new transform function for the given *builtin_name*. + + The transform function must accept two parameters, a node and + an optional context. + """ + def _transform_wrapper(node, context=None): + result = transform(node, context=context) + if result: + result.parent = node + result.lineno = node.lineno + return iter([result]) + + MANAGER.register_transform(nodes.CallFunc, + inference_tip(_transform_wrapper), + lambda n: (isinstance(n.func, nodes.Name) and + n.func.name == builtin_name)) + + +def _generic_inference(node, context, node_type, transform): + args = node.args + if not args: + return node_type() + if len(node.args) > 1: + raise UseInferenceDefault() + + arg, = args + transformed = transform(arg) + if not transformed: + try: + infered = next(arg.infer(context=context)) + except (InferenceError, StopIteration): + raise UseInferenceDefault() + if infered is YES: + raise UseInferenceDefault() + transformed = transform(infered) + if not transformed or transformed is YES: + raise UseInferenceDefault() + return transformed + + +def _generic_transform(arg, klass, iterables, build_elts): + if isinstance(arg, klass): + return arg + elif isinstance(arg, iterables): + if not all(isinstance(elt, nodes.Const) + for elt in arg.elts): + # TODO(cpopa): Don't support heterogenous elements. + # Not yet, though. + raise UseInferenceDefault() + elts = [elt.value for elt in arg.elts] + elif isinstance(arg, nodes.Dict): + if not all(isinstance(elt[0], nodes.Const) + for elt in arg.items): + raise UseInferenceDefault() + elts = [item[0].value for item in arg.items] + elif (isinstance(arg, nodes.Const) and + isinstance(arg.value, six.text_type)): + elts = arg.value + else: + return + return klass(elts=build_elts(elts)) + + +def _infer_builtin(node, context, + klass=None, iterables=None, + build_elts=None): + transform_func = partial( + _generic_transform, + klass=klass, + iterables=iterables, + build_elts=build_elts) + + return _generic_inference(node, context, klass, transform_func) + +# pylint: disable=invalid-name +infer_tuple = partial( + _infer_builtin, + klass=nodes.Tuple, + iterables=(nodes.List, nodes.Set), + build_elts=tuple) + +infer_list = partial( + _infer_builtin, + klass=nodes.List, + iterables=(nodes.Tuple, nodes.Set), + build_elts=list) + +infer_set = partial( + _infer_builtin, + klass=nodes.Set, + iterables=(nodes.List, nodes.Tuple), + build_elts=set) + +# Builtins inference +register_builtin_transform(infer_tuple, 'tuple') +register_builtin_transform(infer_set, 'set') +register_builtin_transform(infer_list, 'list') +# Not exactly the same as set, though. +register_builtin_transform(infer_set, 'frozenset') diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 71e512f..9607f2a 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -46,7 +46,7 @@ def unpack_infer(stmt, context=None): if infered is stmt: yield infered return - # else, infer recursivly, except YES object that should be returned as is + # else, infer recursivly, except YES object that should be returned as is for infered in stmt.infer(context): if infered is YES: yield infered diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index b136a69..eda8fdd 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -529,8 +529,8 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): name = test_utils.extract_node(code, __name__) it = name.infer() tags = next(it) - self.assertEqual(tags.__class__, Instance) - self.assertEqual(tags._proxied.name, 'list') + self.assertIsInstance(tags, nodes.List) + self.assertEqual(tags.elts, []) with self.assertRaises(StopIteration): next(it) @@ -1368,7 +1368,7 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): pass """) self.assertIn( - 'object', + 'object', [base.name for base in klass.ancestors()]) def test_stop_iteration_leak(self): @@ -1381,7 +1381,114 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): astroid = test_utils.extract_node(code, __name__) expr = astroid.func.expr self.assertIs(next(expr.infer()), YES) - + + def _test_builtin_inference(self, node_type, node, elts): + infered = next(node.infer()) + self.assertIsInstance(infered, node_type) + self.assertEqual(sorted(elt.value for elt in infered.elts), + elts) + + def test_tuple_builtin_inference(self): + code = """ + var = (1, 2) + tuple() #@ + tuple([1]) #@ + tuple({2}) #@ + tuple("abc") #@ + tuple({1: 2}) #@ + tuple(var) #@ + tuple(tuple([1])) #@ + + tuple(None) #@ + tuple(1) #@ + tuple(1, 2) #@ + """ + astroid = test_utils.extract_node(code, __name__) + tuple_inference = partial(self._test_builtin_inference, nodes.Tuple) + + tuple_inference(astroid[0], []) + tuple_inference(astroid[1], [1]) + tuple_inference(astroid[2], [2]) + tuple_inference(astroid[3], ["a", "b", "c"]) + tuple_inference(astroid[4], [1]) + tuple_inference(astroid[5], [1, 2]) + tuple_inference(astroid[6], [1]) + + for node in astroid[7:]: + infered = next(node.infer()) + self.assertIsInstance(infered, Instance) + self.assertEqual(infered.qname(), "{}.tuple".format(BUILTINS)) + + def test_set_builtin_inference(self): + code = """ + var = (1, 2) + set() #@ + set([1, 2, 1]) #@ + set({2, 3, 1}) #@ + set("abcab") #@ + set({1: 2}) #@ + set(var) #@ + set(tuple([1])) #@ + frozenset([1, 2, 1]) #@ + frozenset({1: 2, 2: 3}) #@ + frozenset() #@ + + set(set(tuple([4, 5, set([2])]))) #@ + set(None) #@ + set(1) #@ + set(1, 2) #@ + """ + astroid = test_utils.extract_node(code, __name__) + tuple_inference = partial(self._test_builtin_inference, nodes.Set) + + tuple_inference(astroid[0], []) + tuple_inference(astroid[1], [1, 2]) + tuple_inference(astroid[2], [1, 2, 3]) + tuple_inference(astroid[3], ["a", "b", "c"]) + tuple_inference(astroid[4], [1]) + tuple_inference(astroid[5], [1, 2]) + tuple_inference(astroid[6], [1]) + tuple_inference(astroid[7], [1, 2]) + tuple_inference(astroid[8], [1, 2]) + tuple_inference(astroid[9], []) + + for node in astroid[10:]: + infered = next(node.infer()) + self.assertIsInstance(infered, Instance) + self.assertEqual(infered.qname(), "{}.set".format(BUILTINS)) + + def test_list_builtin_inference(self): + code = """ + var = (1, 2) + list() #@ + list([1, 2, 1]) #@ + list({2, 3, 1}) #@ + list("abcab") #@ + list({1: 2}) #@ + list(var) #@ + list(tuple([1])) #@ + + list(list(tuple([4, 5, list([2])]))) #@ + list(None) #@ + list(1) #@ + list(1, 2) #@ + """ + astroid = test_utils.extract_node(code, __name__) + tuple_inference = partial(self._test_builtin_inference, nodes.List) + + tuple_inference(astroid[0], []) + tuple_inference(astroid[1], [1, 1, 2]) + tuple_inference(astroid[2], [1, 2, 3]) + tuple_inference(astroid[3], ["a", "a", "b", "b", "c"]) + tuple_inference(astroid[4], [1]) + tuple_inference(astroid[5], [1, 2]) + tuple_inference(astroid[6], [1]) + + for node in astroid[7:]: + infered = next(node.infer()) + self.assertIsInstance(infered, Instance) + self.assertEqual(infered.qname(), "{}.list".format(BUILTINS)) + if __name__ == '__main__': unittest.main() |