diff options
author | Torsten Marek <shlomme@gmail.com> | 2014-11-22 00:23:13 +0100 |
---|---|---|
committer | Torsten Marek <shlomme@gmail.com> | 2014-11-22 00:23:13 +0100 |
commit | 0be7813bfb4e21ddb9164656e32c4346831d4e15 (patch) | |
tree | 45afde480718ec3d63c67ce63873ceff7af7f5de | |
parent | 5ca590070112c4d593de0488bac8a11cf328fe27 (diff) | |
download | astroid-0be7813bfb4e21ddb9164656e32c4346831d4e15.tar.gz |
Pass along the context when inferring the call result of a class to avoid infinite loops.
-rw-r--r-- | astroid/scoped_nodes.py | 10 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 14 |
2 files changed, 19 insertions, 5 deletions
diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index c70618b..52934f1 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -952,17 +952,17 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): def callable(self): return True - def _is_subtype_of(self, type_name): + def _is_subtype_of(self, type_name, context): if self.qname() == type_name: return True - for anc in self.ancestors(): + for anc in self.ancestors(context=context): if anc.qname() == type_name: return True def infer_call_result(self, caller, context=None): """infer what a class is returning when called""" - if self._is_subtype_of('%s.type' % (BUILTINS,)) and len(caller.args) == 3: - name_node = next(caller.args[0].infer()) + if self._is_subtype_of('%s.type' % (BUILTINS,), context) and len(caller.args) == 3: + name_node = next(caller.args[0].infer(context)) if (isinstance(name_node, Const) and isinstance(name_node.value, six.string_types)): name = name_node.value @@ -970,7 +970,7 @@ class Class(Statement, LocalsDictNodeNG, FilterStmtsMixin): yield YES return result = Class(name, None) - bases = next(caller.args[1].infer()) + bases = next(caller.args[1].infer(context)) if isinstance(bases, (Tuple, List)): result.bases = bases.itered() else: diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 2fdb689..6ef0d62 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -1357,6 +1357,20 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): node = astroid['do_a_thing'] self.assertEqual(node.type, 'function') + def test_no_infinite_ancestor_loop(self): + klass = test_utils.extract_node(""" + import datetime + + def method(self): + datetime.datetime = something() + + class something(datetime.datetime): #@ + pass + """) + self.assertIn( + 'object', + [base.name for base in klass.ancestors()]) + if __name__ == '__main__': unittest.main() |