diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2014-12-29 15:09:56 +0200 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2014-12-29 15:09:56 +0200 |
commit | b95b0c855cc76ed3eb86fc34b7e97c64f33ede31 (patch) | |
tree | bbb15d76efada21c8e12973bd48cb92d436a3e28 | |
parent | e51c0dcd66e4319739450348bca23f555c8c4b6a (diff) | |
download | astroid-b95b0c855cc76ed3eb86fc34b7e97c64f33ede31.tar.gz |
Improve the detection for functions decorated with decorators
which returns static or class methods.
-rw-r--r-- | ChangeLog | 3 | ||||
-rw-r--r-- | astroid/scoped_nodes.py | 97 | ||||
-rw-r--r-- | astroid/tests/unittest_scoped_nodes.py | 39 |
3 files changed, 84 insertions, 55 deletions
@@ -27,6 +27,9 @@ Change log for the astroid package (used to be astng) * Add brain tips for six.moves. Closes issue #63. + * Improve the detection for functions decorated with decorators + which returns static or class methods. + 2014-11-22 -- 1.3.2 diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index a52a3e8..a5dbd6d 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -556,50 +556,28 @@ else: # Function ################################################################### def _infer_decorator_callchain(node): - """ Detect decorator call chaining and see if the - end result is a static or a classmethod. + """Detect decorator call chaining and see if the end result is a + static or a classmethod. """ - current = node - while True: - if isinstance(current, CallFunc): - try: - current = next(current.func.infer()) - except InferenceError: - return - elif isinstance(current, Function): - if not current.parent: - return - try: - # TODO: We don't handle multiple inference results right now, - # because there's no flow to reason when the return - # is what we are looking for, a static or a class method. - result = next(current.infer_call_result(current.parent)) - if current is result: - # This will lead to an infinite loop, where a decorator - # returns itself. - return - except (StopIteration, InferenceError): - return - if isinstance(result, (Function, CallFunc)): - current = result - else: - if isinstance(result, Instance): - result = result._proxied - if isinstance(result, Class): - if (result.name == 'classmethod' and - result.root().name == BUILTINS): - return 'classmethod' - elif (result.name == 'staticmethod' and - result.root().name == BUILTINS): - return 'staticmethod' - else: - return - else: - # We aren't interested in anything else returned, - # so go back to the function type inference. - return - else: - return + if not isinstance(node, Function): + return + if not node.parent: + return + try: + # TODO: We don't handle multiple inference results right now, + # because there's no flow to reason when the return + # is what we are looking for, a static or a class method. + result = next(node.infer_call_result(node.parent)) + except (StopIteration, InferenceError): + return + if isinstance(result, Instance): + result = result._proxied + if isinstance(result, Class): + if result.is_subtype_of('%s.classmethod' % BUILTINS): + return 'classmethod' + if result.is_subtype_of('%s.staticmethod' % BUILTINS): + return 'staticmethod' + def _function_type(self): """ @@ -612,25 +590,34 @@ def _function_type(self): if self.decorators: for node in self.decorators.nodes: if isinstance(node, CallFunc): - _type = _infer_decorator_callchain(node) - if _type is None: + # Handle the following case: + # @some_decorator(arg1, arg2) + # def func(...) + # + try: + current = next(node.func.infer()) + except InferenceError: continue - else: + _type = _infer_decorator_callchain(current) + if _type is not None: return _type - if not isinstance(node, Name): - continue + try: for infered in node.infer(): + # Check to see if this returns a static or a class method. + _type = _infer_decorator_callchain(infered) + if _type is not None: + return _type + if not isinstance(infered, Class): continue for ancestor in infered.ancestors(): - if isinstance(ancestor, Class): - if (ancestor.name == 'classmethod' and - ancestor.root().name == BUILTINS): - return 'classmethod' - elif (ancestor.name == 'staticmethod' and - ancestor.root().name == BUILTINS): - return 'staticmethod' + if not isinstance(ancestor, Class): + continue + if ancestor.is_subtype_of('%s.classmethod' % BUILTINS): + return 'classmethod' + elif ancestor.is_subtype_of('%s.staticmethod' % BUILTINS): + return 'staticmethod' except InferenceError: pass return self._type diff --git a/astroid/tests/unittest_scoped_nodes.py b/astroid/tests/unittest_scoped_nodes.py index e38d2a6..a0e5a66 100644 --- a/astroid/tests/unittest_scoped_nodes.py +++ b/astroid/tests/unittest_scoped_nodes.py @@ -458,12 +458,36 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase): return staticmethod(f) return wrapper + def long_classmethod_decorator(platform=None, order=50): + def wrapper(f): + def wrapper2(f): + def wrapper3(f): + f.cgm_module = True + f.cgm_module_order = order + f.cgm_module_platform = platform + return classmethod(f) + return wrapper3(f) + return wrapper2(f) + return wrapper + def classmethod_decorator(platform=None): def wrapper(f): f.platform = platform return classmethod(f) return wrapper + def classmethod_wrapper(fn): + def wrapper(cls, *args, **kwargs): + result = fn(cls, *args, **kwargs) + return result + + return classmethod(wrapper) + + def staticmethod_wrapper(fn): + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + return staticmethod(wrapper) + class SomeClass(object): @static_decorator() def static(node, cfg): @@ -477,6 +501,15 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase): @classmethod_decorator def not_so_classmethod(node): pass + @classmethod_wrapper + def classmethod_wrapped(cls): + pass + @staticmethod_wrapper + def staticmethod_wrapped(): + pass + @long_classmethod_decorator() + def long_classmethod(cls): + pass """) node = astroid.locals['SomeClass'][0] self.assertEqual(node.locals['static'][0].type, @@ -487,6 +520,12 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase): 'method') self.assertEqual(node.locals['not_so_classmethod'][0].type, 'method') + self.assertEqual(node.locals['classmethod_wrapped'][0].type, + 'classmethod') + self.assertEqual(node.locals['staticmethod_wrapped'][0].type, + 'staticmethod') + self.assertEqual(node.locals['long_classmethod'][0].type, + 'classmethod') class ClassNodeTest(ModuleLoader, unittest.TestCase): |