summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSylvain Th?nault <sylvain.thenault@logilab.fr>2011-04-08 14:48:42 +0200
committerSylvain Th?nault <sylvain.thenault@logilab.fr>2011-04-08 14:48:42 +0200
commit16b24ec7837e73dce47b6b4e201b81d6a510ade7 (patch)
treee1231c791f5fcf09ad48924c8e342f0507eb6e74
parent7e5473cfdf80ffb05852caee3d6ab16452c561de (diff)
downloadlogilab-common-16b24ec7837e73dce47b6b4e201b81d6a510ade7.tar.gz
decorators: refactored @cached to allow usages such as @cached(cacheattr='_cachename') while keeping bw compat
-rw-r--r--ChangeLog7
-rw-r--r--decorators.py143
-rw-r--r--test/unittest_decorators.py38
3 files changed, 123 insertions, 65 deletions
diff --git a/ChangeLog b/ChangeLog
index c3044f4..b2b4505 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -2,8 +2,11 @@ ChangeLog for logilab.common
============================
--
-* date: new datetime/delta <-> seconds/days conversion function
-
+ * date: new datetime/delta <-> seconds/days conversion function
+
+ * decorators: refactored @cached to allow usages such as
+ @cached(cacheattr='_cachename') while keeping bw compat
+
2011-04-01 -- 0.55.2
* new function for password generation in shellutils
diff --git a/decorators.py b/decorators.py
index 8ace461..336af9d 100644
--- a/decorators.py
+++ b/decorators.py
@@ -19,98 +19,117 @@
__docformat__ = "restructuredtext en"
import types
-from time import clock, time
import sys, re
+from time import clock, time
# XXX rewrite so we can use the decorator syntax when keyarg has to be specified
def _is_generator_function(callableobj):
return callableobj.func_code.co_flags & 0x20
-def cached(callableobj, keyarg=None):
- """Simple decorator to cache result of method call."""
- assert not _is_generator_function(callableobj), 'cannot cache generator function: %s' % callableobj
- if callableobj.func_code.co_argcount == 1 or keyarg == 0:
-
- def cache_wrapper1(self, *args):
- cache = '_%s_cache_' % callableobj.__name__
- #print 'cache1?', cache
- try:
- return self.__dict__[cache]
- except KeyError:
- #print 'miss'
- value = callableobj(self, *args)
- setattr(self, cache, value)
- return value
+class cached_decorator(object):
+ def __init__(self, cacheattr=None, keyarg=None):
+ self.cacheattr = cacheattr
+ self.keyarg = keyarg
+ def __call__(self, callableobj=None):
+ assert not _is_generator_function(callableobj), \
+ 'cannot cache generator function: %s' % callableobj
+ if callableobj.func_code.co_argcount == 1 or self.keyarg == 0:
+ cache = _SingleValueCache(callableobj, self.cacheattr)
+ elif self.keyarg:
+ cache = _MultiValuesKeyArgCache(callableobj, self.keyarg, self.cacheattr)
+ print 'hop'
+ else:
+ cache = _MultiValuesCache(callableobj, self.cacheattr)
+ return cache.closure()
+
+class _SingleValueCache(object):
+ def __init__(self, callableobj, cacheattr=None):
+ self.callable = callableobj
+ if cacheattr is None:
+ self.cacheattr = '_%s_cache_' % callableobj.__name__
+ else:
+ assert cacheattr != callableobj.__name__
+ self.cacheattr = cacheattr
+
+ def __call__(__me, self, *args):
+ try:
+ return self.__dict__[__me.cacheattr]
+ except KeyError:
+ value = __me.callable(self, *args)
+ setattr(self, __me.cacheattr, value)
+ return value
+
+ def closure(self):
+ def wrapped(*args, **kwargs):
+ return self.__call__(*args, **kwargs)
+ wrapped.clear = self.clear
try:
- cache_wrapper1.__doc__ = callableobj.__doc__
- cache_wrapper1.func_name = callableobj.func_name
+ wrapped.__doc__ = self.callable.__doc__
+ wrapped.__name__ = self.callable.__name__
+ wrapped.func_name = self.callable.func_name
except:
pass
- return cache_wrapper1
+ return wrapped
- elif keyarg:
+ def clear(self, holder):
+ holder.__dict__.pop(self.cacheattr, None)
- def cache_wrapper2(self, *args, **kwargs):
- cache = '_%s_cache_' % callableobj.__name__
- key = args[keyarg-1]
- #print 'cache2?', cache, self, key
- try:
- _cache = self.__dict__[cache]
- except KeyError:
- #print 'init'
- _cache = {}
- setattr(self, cache, _cache)
- try:
- return _cache[key]
- except KeyError:
- #print 'miss', self, cache, key
- _cache[key] = callableobj(self, *args, **kwargs)
- return _cache[key]
- try:
- cache_wrapper2.__doc__ = callableobj.__doc__
- cache_wrapper2.func_name = callableobj.func_name
- except:
- pass
- return cache_wrapper2
- def cache_wrapper3(self, *args):
- cache = '_%s_cache_' % callableobj.__name__
- #print 'cache3?', cache, self, args
+class _MultiValuesCache(_SingleValueCache):
+ def _get_cache(self, holder):
try:
- _cache = self.__dict__[cache]
+ _cache = holder.__dict__[self.cacheattr]
except KeyError:
- #print 'init'
_cache = {}
- setattr(self, cache, _cache)
+ setattr(holder, self.cacheattr, _cache)
+ return _cache
+
+ def __call__(__me, self, *args, **kwargs):
+ _cache = __me._get_cache(self)
try:
return _cache[args]
except KeyError:
- #print 'miss'
- _cache[args] = callableobj(self, *args)
- return _cache[args]
- try:
- cache_wrapper3.__doc__ = callableobj.__doc__
- cache_wrapper3.func_name = callableobj.func_name
- except:
- pass
- return cache_wrapper3
+ _cache[args] = __me.callable(self, *args)
+ return _cache[args]
+
+class _MultiValuesKeyArgCache(_MultiValuesCache):
+ def __init__(self, callableobj, keyarg, cacheattr=None):
+ super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr)
+ self.keyarg = keyarg
+
+ def __call__(__me, self, *args, **kwargs):
+ _cache = __me._get_cache(self)
+ key = args[__me.keyarg-1]
+ try:
+ return _cache[key]
+ except KeyError:
+ _cache[key] = __me.callable(self, *args, **kwargs)
+ return _cache[key]
+
+
+def cached(callableobj=None, keyarg=None, **kwargs):
+ """Simple decorator to cache result of method call."""
+ kwargs['keyarg'] = keyarg
+ decorator = cached_decorator(**kwargs)
+ if callableobj is None:
+ return decorator
+ else:
+ return decorator(callableobj)
def clear_cache(obj, funcname):
"""Function to clear a cache handled by the cached decorator."""
- try:
- del obj.__dict__['_%s_cache_' % funcname]
- except KeyError:
- pass
+ getattr(obj, funcname).clear(obj)
def copy_cache(obj, funcname, cacheobj):
"""Copy cache for <funcname> from cacheobj to obj."""
- cache = '_%s_cache_' % funcname
+ cache = getattr(obj, funcname).cacheattr
try:
setattr(obj, cache, cacheobj.__dict__[cache])
except KeyError:
pass
+
class wproperty(object):
"""Simple descriptor expecting to take a modifier function as first argument
and looking for a _<function name> to retrieve the attribute.
diff --git a/test/unittest_decorators.py b/test/unittest_decorators.py
index 5c598f4..b3321fc 100644
--- a/test/unittest_decorators.py
+++ b/test/unittest_decorators.py
@@ -19,7 +19,7 @@
"""
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.decorators import monkeypatch, cached
+from logilab.common.decorators import monkeypatch, cached, clear_cache
class DecoratorsTC(TestCase):
@@ -60,11 +60,47 @@ class DecoratorsTC(TestCase):
def quux(self, zogzog):
""" what's up doc ? """
self.assertEqual(Foo.foo.__doc__, """ what's up doc ? """)
+ self.assertEqual(Foo.foo.__name__, 'foo')
self.assertEqual(Foo.foo.func_name, 'foo')
self.assertEqual(Foo.bar.__doc__, """ what's up doc ? """)
+ self.assertEqual(Foo.bar.__name__, 'bar')
self.assertEqual(Foo.bar.func_name, 'bar')
self.assertEqual(Foo.quux.__doc__, """ what's up doc ? """)
+ self.assertEqual(Foo.quux.__name__, 'quux')
self.assertEqual(Foo.quux.func_name, 'quux')
+ def test_cached_single_cache(self):
+ class Foo(object):
+ @cached(cacheattr=u'_foo')
+ def foo(self):
+ """ what's up doc ? """
+ foo = Foo()
+ foo.foo()
+ self.assertTrue(hasattr(foo, '_foo'))
+ clear_cache(foo, 'foo')
+ self.assertFalse(hasattr(foo, '_foo'))
+
+ def test_cached_multi_cache(self):
+ class Foo(object):
+ @cached(cacheattr=u'_foo')
+ def foo(self, args):
+ """ what's up doc ? """
+ foo = Foo()
+ foo.foo(1)
+ self.assertEqual(foo._foo, {(1,): None})
+ clear_cache(foo, 'foo')
+ self.assertFalse(hasattr(foo, '_foo'))
+
+ def test_cached_keyarg_cache(self):
+ class Foo(object):
+ @cached(cacheattr=u'_foo', keyarg=1)
+ def foo(self, other, args):
+ """ what's up doc ? """
+ foo = Foo()
+ foo.foo(2, 1)
+ self.assertEqual(foo._foo, {2: None})
+ clear_cache(foo, 'foo')
+ self.assertFalse(hasattr(foo, '_foo'))
+
if __name__ == '__main__':
unittest_main()