summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2020-03-05 10:28:34 +0100
committerClaudiu Popa <pcmanticore@gmail.com>2020-03-05 10:28:34 +0100
commit555085e0bd850c5381e29d15294cd37287f79bd6 (patch)
treecbbb2f018bd5545227f60d4544af2161e45baba2
parent5f0675c41ff8c463ab0d657fb6756aa9679cffbf (diff)
downloadastroid-git-555085e0bd850c5381e29d15294cd37287f79bd6.tar.gz
Prevent a recursion error when inferring self-referential variables without definition
Close PyCQA/pylint#1285
-rw-r--r--ChangeLog4
-rw-r--r--astroid/protocols.py2
-rw-r--r--tests/unittest_inference.py17
3 files changed, 22 insertions, 1 deletions
diff --git a/ChangeLog b/ChangeLog
index 75a9b99b..2e712db2 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -34,6 +34,10 @@ Release Date: TBA
Close #762
+* Prevent a recursion error when inferring self-referential variables without definition
+
+ Close PyCQA/pylint#1285
+
* Numpy `datetime64.astype` return value is inferred as a `ndarray`.
Close PyCQA/pylint#3332
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 33d90ea3..6179ab34 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -367,7 +367,7 @@ def arguments_assigned_stmts(self, node=None, context=None, assign_path=None):
callcontext = context.callcontext
context = contextmod.copy_context(context)
context.callcontext = None
- args = arguments.CallSite(callcontext)
+ args = arguments.CallSite(callcontext, context=context)
return args.infer_argument(self.parent, node.name, context)
return _arguments_infer_argname(self, node.name, context)
diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py
index ec396f8a..4dc785ab 100644
--- a/tests/unittest_inference.py
+++ b/tests/unittest_inference.py
@@ -5667,5 +5667,22 @@ def test_dataclasses_subscript_inference_recursion_error():
assert helpers.safe_infer(node) is None
+def test_self_reference_infer_does_not_trigger_recursion_error():
+ # Prevents https://github.com/PyCQA/pylint/issues/1285
+ code = """
+ def func(elems):
+ return elems
+
+ class BaseModel(object):
+
+ def __init__(self, *args, **kwargs):
+ self._reference = func(*self._reference.split('.'))
+ BaseModel()._reference
+ """
+ node = extract_node(code)
+ inferred = next(node.infer())
+ assert inferred is util.Uninferable
+
+
if __name__ == "__main__":
unittest.main()