summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2014-12-29 15:09:56 +0200
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2014-12-29 15:09:56 +0200
commitb95b0c855cc76ed3eb86fc34b7e97c64f33ede31 (patch)
treebbb15d76efada21c8e12973bd48cb92d436a3e28
parente51c0dcd66e4319739450348bca23f555c8c4b6a (diff)
downloadastroid-b95b0c855cc76ed3eb86fc34b7e97c64f33ede31.tar.gz
Improve the detection for functions decorated with decorators
which returns static or class methods.
-rw-r--r--ChangeLog3
-rw-r--r--astroid/scoped_nodes.py97
-rw-r--r--astroid/tests/unittest_scoped_nodes.py39
3 files changed, 84 insertions, 55 deletions
diff --git a/ChangeLog b/ChangeLog
index 8975783..b9dfa5b 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -27,6 +27,9 @@ Change log for the astroid package (used to be astng)
* Add brain tips for six.moves. Closes issue #63.
+ * Improve the detection for functions decorated with decorators
+ which returns static or class methods.
+
2014-11-22 -- 1.3.2
diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py
index a52a3e8..a5dbd6d 100644
--- a/astroid/scoped_nodes.py
+++ b/astroid/scoped_nodes.py
@@ -556,50 +556,28 @@ else:
# Function ###################################################################
def _infer_decorator_callchain(node):
- """ Detect decorator call chaining and see if the
- end result is a static or a classmethod.
+ """Detect decorator call chaining and see if the end result is a
+ static or a classmethod.
"""
- current = node
- while True:
- if isinstance(current, CallFunc):
- try:
- current = next(current.func.infer())
- except InferenceError:
- return
- elif isinstance(current, Function):
- if not current.parent:
- return
- try:
- # TODO: We don't handle multiple inference results right now,
- # because there's no flow to reason when the return
- # is what we are looking for, a static or a class method.
- result = next(current.infer_call_result(current.parent))
- if current is result:
- # This will lead to an infinite loop, where a decorator
- # returns itself.
- return
- except (StopIteration, InferenceError):
- return
- if isinstance(result, (Function, CallFunc)):
- current = result
- else:
- if isinstance(result, Instance):
- result = result._proxied
- if isinstance(result, Class):
- if (result.name == 'classmethod' and
- result.root().name == BUILTINS):
- return 'classmethod'
- elif (result.name == 'staticmethod' and
- result.root().name == BUILTINS):
- return 'staticmethod'
- else:
- return
- else:
- # We aren't interested in anything else returned,
- # so go back to the function type inference.
- return
- else:
- return
+ if not isinstance(node, Function):
+ return
+ if not node.parent:
+ return
+ try:
+ # TODO: We don't handle multiple inference results right now,
+ # because there's no flow to reason when the return
+ # is what we are looking for, a static or a class method.
+ result = next(node.infer_call_result(node.parent))
+ except (StopIteration, InferenceError):
+ return
+ if isinstance(result, Instance):
+ result = result._proxied
+ if isinstance(result, Class):
+ if result.is_subtype_of('%s.classmethod' % BUILTINS):
+ return 'classmethod'
+ if result.is_subtype_of('%s.staticmethod' % BUILTINS):
+ return 'staticmethod'
+
def _function_type(self):
"""
@@ -612,25 +590,34 @@ def _function_type(self):
if self.decorators:
for node in self.decorators.nodes:
if isinstance(node, CallFunc):
- _type = _infer_decorator_callchain(node)
- if _type is None:
+ # Handle the following case:
+ # @some_decorator(arg1, arg2)
+ # def func(...)
+ #
+ try:
+ current = next(node.func.infer())
+ except InferenceError:
continue
- else:
+ _type = _infer_decorator_callchain(current)
+ if _type is not None:
return _type
- if not isinstance(node, Name):
- continue
+
try:
for infered in node.infer():
+ # Check to see if this returns a static or a class method.
+ _type = _infer_decorator_callchain(infered)
+ if _type is not None:
+ return _type
+
if not isinstance(infered, Class):
continue
for ancestor in infered.ancestors():
- if isinstance(ancestor, Class):
- if (ancestor.name == 'classmethod' and
- ancestor.root().name == BUILTINS):
- return 'classmethod'
- elif (ancestor.name == 'staticmethod' and
- ancestor.root().name == BUILTINS):
- return 'staticmethod'
+ if not isinstance(ancestor, Class):
+ continue
+ if ancestor.is_subtype_of('%s.classmethod' % BUILTINS):
+ return 'classmethod'
+ elif ancestor.is_subtype_of('%s.staticmethod' % BUILTINS):
+ return 'staticmethod'
except InferenceError:
pass
return self._type
diff --git a/astroid/tests/unittest_scoped_nodes.py b/astroid/tests/unittest_scoped_nodes.py
index e38d2a6..a0e5a66 100644
--- a/astroid/tests/unittest_scoped_nodes.py
+++ b/astroid/tests/unittest_scoped_nodes.py
@@ -458,12 +458,36 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase):
return staticmethod(f)
return wrapper
+ def long_classmethod_decorator(platform=None, order=50):
+ def wrapper(f):
+ def wrapper2(f):
+ def wrapper3(f):
+ f.cgm_module = True
+ f.cgm_module_order = order
+ f.cgm_module_platform = platform
+ return classmethod(f)
+ return wrapper3(f)
+ return wrapper2(f)
+ return wrapper
+
def classmethod_decorator(platform=None):
def wrapper(f):
f.platform = platform
return classmethod(f)
return wrapper
+ def classmethod_wrapper(fn):
+ def wrapper(cls, *args, **kwargs):
+ result = fn(cls, *args, **kwargs)
+ return result
+
+ return classmethod(wrapper)
+
+ def staticmethod_wrapper(fn):
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+ return staticmethod(wrapper)
+
class SomeClass(object):
@static_decorator()
def static(node, cfg):
@@ -477,6 +501,15 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase):
@classmethod_decorator
def not_so_classmethod(node):
pass
+ @classmethod_wrapper
+ def classmethod_wrapped(cls):
+ pass
+ @staticmethod_wrapper
+ def staticmethod_wrapped():
+ pass
+ @long_classmethod_decorator()
+ def long_classmethod(cls):
+ pass
""")
node = astroid.locals['SomeClass'][0]
self.assertEqual(node.locals['static'][0].type,
@@ -487,6 +520,12 @@ class FunctionNodeTest(ModuleLoader, unittest.TestCase):
'method')
self.assertEqual(node.locals['not_so_classmethod'][0].type,
'method')
+ self.assertEqual(node.locals['classmethod_wrapped'][0].type,
+ 'classmethod')
+ self.assertEqual(node.locals['staticmethod_wrapped'][0].type,
+ 'staticmethod')
+ self.assertEqual(node.locals['long_classmethod'][0].type,
+ 'classmethod')
class ClassNodeTest(ModuleLoader, unittest.TestCase):