summaryrefslogtreecommitdiff
path: root/astroid/objects.py
diff options
context:
space:
mode:
authorCraig Franklin <craigjfranklin@gmail.com>2021-10-10 17:20:44 +1100
committerGitHub <noreply@github.com>2021-10-10 08:20:44 +0200
commit4f11b666b85082d66c18fce172027929637ea09f (patch)
tree57cddae5a53979189370d910a62da7c48f9e25eb /astroid/objects.py
parent82b94253ac4165d8a49854e2685da63cea01f2e6 (diff)
downloadastroid-git-4f11b666b85082d66c18fce172027929637ea09f.tar.gz
Add recognition of previous partial args/kwargs when chaining (#1209)
* Add recognition of previous partial args/kwargs when chaining Repeat use of functools.partial on the same function would only keep the args & kwargs locked in by the last call to partial. By checking if the wrapped function is itself a partial, and including its filled args/kwargs if it is, we can keep track of all locked args/kwargs across any number of nested partial calls. Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
Diffstat (limited to 'astroid/objects.py')
-rw-r--r--astroid/objects.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/astroid/objects.py b/astroid/objects.py
index 4241c170..ba2e4178 100644
--- a/astroid/objects.py
+++ b/astroid/objects.py
@@ -265,10 +265,20 @@ class PartialFunction(scoped_nodes.FunctionDef):
# A typical FunctionDef automatically adds its name to the parent scope,
# but a partial should not, so defer setting parent until after init
self.parent = parent
- self.filled_positionals = len(call.positional_arguments[1:])
self.filled_args = call.positional_arguments[1:]
self.filled_keywords = call.keyword_arguments
+ wrapped_function = call.positional_arguments[0]
+ inferred_wrapped_function = next(wrapped_function.infer())
+ if isinstance(inferred_wrapped_function, PartialFunction):
+ self.filled_args = inferred_wrapped_function.filled_args + self.filled_args
+ self.filled_keywords = {
+ **inferred_wrapped_function.filled_keywords,
+ **self.filled_keywords,
+ }
+
+ self.filled_positionals = len(self.filled_args)
+
def infer_call_result(self, caller=None, context=None):
if context:
current_passed_keywords = {