1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
|
"""Redis cache backend."""
import pickle
import random
import re
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
class RedisSerializer:
    """Pickle-based value serializer that lets plain integers pass through.

    ``protocol`` is the pickle protocol used by ``dumps()``; when it is
    ``None`` the highest protocol available is used.
    """

    def __init__(self, protocol=None):
        if protocol is None:
            protocol = pickle.HIGHEST_PROTOCOL
        self.protocol = protocol

    def dumps(self, obj):
        """Return ``obj`` itself if it is exactly an ``int``, else its pickle."""
        # The exact type check is deliberate: only true integers skip
        # pickling, while int subclasses such as bool are still pickled.
        if type(obj) is not int:
            return pickle.dumps(obj, self.protocol)
        return obj

    def loads(self, data):
        """Return ``data`` as an integer if it parses as one, else unpickle it."""
        try:
            number = int(data)
        except ValueError:
            # NOTE(review): pickle.loads() can execute arbitrary code, so
            # the stored data must come from a trusted source.
            return pickle.loads(data)
        return number
class RedisCacheClient:
    """Cache operations implemented on top of redis-py connection pools.

    ``servers`` is a sequence of connection URLs. The first one receives
    every write; reads are spread randomly over the remaining ones, or go to
    the first one when it is the only server.
    """

    def __init__(
        self,
        servers,
        serializer=None,
        pool_class=None,
        parser_class=None,
        **options,
    ):
        """Store the configuration; connection pools are created on demand.

        ``serializer``, ``pool_class`` and ``parser_class`` may each be given
        as a dotted import path. A callable ``serializer`` is called without
        arguments to obtain the serializer object. Any remaining ``options``
        are forwarded to ``pool_class.from_url()``.
        """
        # Imported here rather than at module level, so redis-py is only
        # needed once a client is actually instantiated.
        import redis

        self._lib = redis
        self._servers = servers
        # Connection pools keyed by server index, filled lazily.
        self._pools = {}

        self._client = self._lib.Redis

        if isinstance(pool_class, str):
            pool_class = import_string(pool_class)
        self._pool_class = pool_class or self._lib.ConnectionPool

        if isinstance(serializer, str):
            serializer = import_string(serializer)
        if callable(serializer):
            serializer = serializer()
        self._serializer = serializer or RedisSerializer()

        if isinstance(parser_class, str):
            parser_class = import_string(parser_class)
        parser_class = parser_class or self._lib.connection.DefaultParser

        self._pool_options = {"parser_class": parser_class, **options}

    def _get_connection_pool_index(self, write):
        """Return the index in ``self._servers`` to use for this operation."""
        # Write to the first server. Read from other servers if there are more,
        # otherwise read from the first server.
        if write or len(self._servers) == 1:
            return 0
        return random.randint(1, len(self._servers) - 1)

    def _get_connection_pool(self, write):
        """Return the pool for the selected server, creating it on first use."""
        index = self._get_connection_pool_index(write)
        if index not in self._pools:
            self._pools[index] = self._pool_class.from_url(
                self._servers[index],
                **self._pool_options,
            )
        return self._pools[index]

    def get_client(self, key=None, *, write=False):
        """Return a client object bound to the read or write connection pool."""
        # key is used so that the method signature remains the same and custom
        # cache client can be implemented which might require the key to select
        # the server, e.g. sharding.
        pool = self._get_connection_pool(write)
        return self._client(connection_pool=pool)

    def add(self, key, value, timeout):
        """Store ``value`` only if ``key`` is absent; return whether it was stored."""
        client = self.get_client(key, write=True)
        value = self._serializer.dumps(value)

        if timeout == 0:
            # A zero timeout means "already expired": set the key to learn
            # whether it was free, then remove it again straight away.
            if ret := bool(client.set(key, value, nx=True)):
                client.delete(key)
            return ret
        else:
            return bool(client.set(key, value, ex=timeout, nx=True))

    def get(self, key, default):
        """Return the deserialized value for ``key``, or ``default`` if missing."""
        client = self.get_client(key)
        value = client.get(key)
        return default if value is None else self._serializer.loads(value)

    def set(self, key, value, timeout):
        """Store ``value`` under ``key``; a timeout of 0 deletes the key instead."""
        client = self.get_client(key, write=True)
        value = self._serializer.dumps(value)
        if timeout == 0:
            client.delete(key)
        else:
            client.set(key, value, ex=timeout)

    def touch(self, key, timeout):
        """Reset the expiry of ``key``; ``None`` removes the expiry altogether.

        Returns the result of the underlying command as a bool.
        """
        client = self.get_client(key, write=True)
        if timeout is None:
            return bool(client.persist(key))
        else:
            return bool(client.expire(key, timeout))

    def delete(self, key):
        """Delete ``key`` and return the result of the command as a bool."""
        client = self.get_client(key, write=True)
        return bool(client.delete(key))

    def get_many(self, keys):
        """Return a dict of the keys that have a value, deserialized."""
        client = self.get_client(None)
        ret = client.mget(keys)
        return {
            k: self._serializer.loads(v) for k, v in zip(keys, ret) if v is not None
        }

    def has_key(self, key):
        """Return whether ``key`` exists on the server used for reading."""
        client = self.get_client(key)
        return bool(client.exists(key))

    def incr(self, key, delta):
        """Increment ``key`` by ``delta`` and return the client's result.

        Raises ``ValueError`` if the key does not exist.
        """
        client = self.get_client(key, write=True)
        # NOTE: the existence check and the increment are separate commands.
        if not client.exists(key):
            raise ValueError("Key '%s' not found." % key)
        return client.incr(key, delta)

    def set_many(self, data, timeout):
        """Store every key/value pair of ``data`` through a single pipeline.

        An empty ``data`` is a no-op.
        """
        # MSET without any key/value pair is rejected by Redis, so there is
        # nothing useful to send for an empty mapping.
        if not data:
            return
        client = self.get_client(None, write=True)
        pipeline = client.pipeline()
        pipeline.mset({k: self._serializer.dumps(v) for k, v in data.items()})

        if timeout is not None:
            # Setting timeout for each key as redis does not support timeout
            # with mset().
            for key in data:
                pipeline.expire(key, timeout)
        pipeline.execute()

    def delete_many(self, keys):
        """Delete every key in ``keys``; an empty iterable is a no-op."""
        # Materialize first so that one-shot iterables can be tested for
        # emptiness: DEL without any key is rejected by Redis.
        keys = list(keys)
        if not keys:
            return
        client = self.get_client(None, write=True)
        client.delete(*keys)

    def clear(self):
        """Issue ``flushdb()`` on the write server; return its result as a bool."""
        client = self.get_client(None, write=True)
        return bool(client.flushdb())
class RedisCache(BaseCache):
    """Cache backend that delegates storage to a ``RedisCacheClient``.

    Every public method passes its keys through ``make_and_validate_key()``
    and converts timeouts with ``get_backend_timeout()`` before calling the
    client.
    """

    def __init__(self, server, params):
        """Record the server list and the client options.

        ``server`` is either a sequence of server URLs or a single string in
        which the URLs are separated by ";" or ",". The ``OPTIONS`` entry of
        ``params`` (default: empty) is passed to the client as keyword
        arguments.
        """
        super().__init__(params)
        if isinstance(server, str):
            self._servers = re.split("[;,]", server)
        else:
            self._servers = server

        self._class = RedisCacheClient
        self._options = params.get("OPTIONS", {})

    @cached_property
    def _cache(self):
        """The client instance, created on first access."""
        return self._class(self._servers, **self._options)

    def get_backend_timeout(self, timeout=DEFAULT_TIMEOUT):
        """Convert ``timeout`` to the value handed to the client.

        ``DEFAULT_TIMEOUT`` is replaced by ``self.default_timeout``. ``None``
        is returned unchanged; anything else becomes a non-negative integer.
        """
        if timeout == DEFAULT_TIMEOUT:
            timeout = self.default_timeout
        # The key will be made persistent if None used as a timeout.
        # Non-positive values will cause the key to be deleted.
        return None if timeout is None else max(0, int(timeout))

    def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        """Store ``value`` only if ``key`` is absent; return the client's result."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.add(key, value, self.get_backend_timeout(timeout))

    def get(self, key, default=None, version=None):
        """Return the value stored under ``key``, or ``default`` if missing."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.get(key, default)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        """Store ``value`` under ``key`` with the given timeout."""
        key = self.make_and_validate_key(key, version=version)
        self._cache.set(key, value, self.get_backend_timeout(timeout))

    def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None):
        """Update the expiry of ``key``; return the client's result."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.touch(key, self.get_backend_timeout(timeout))

    def delete(self, key, version=None):
        """Delete ``key``; return the client's result."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.delete(key)

    def get_many(self, keys, version=None):
        """Return a dict mapping the caller's keys to the values found."""
        # Remember which caller key each backend key stands for, so that the
        # result can be reported under the caller's own keys.
        key_map = {
            self.make_and_validate_key(key, version=version): key for key in keys
        }
        ret = self._cache.get_many(key_map.keys())
        return {key_map[k]: v for k, v in ret.items()}

    def has_key(self, key, version=None):
        """Return whether ``key`` exists, as reported by the client."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.has_key(key)

    def incr(self, key, delta=1, version=None):
        """Increment ``key`` by ``delta``; return the client's result."""
        key = self.make_and_validate_key(key, version=version)
        return self._cache.incr(key, delta)

    def set_many(self, data, timeout=DEFAULT_TIMEOUT, version=None):
        """Store every key/value pair of ``data``; always return an empty list.

        Empty ``data`` returns immediately without calling the client.
        """
        if not data:
            return []
        safe_data = {
            self.make_and_validate_key(key, version=version): value
            for key, value in data.items()
        }
        self._cache.set_many(safe_data, self.get_backend_timeout(timeout))
        return []

    def delete_many(self, keys, version=None):
        """Delete every key in ``keys``; do nothing when there are none."""
        if not keys:
            return
        safe_keys = [self.make_and_validate_key(key, version=version) for key in keys]
        # Check again once the keys are materialized: a one-shot iterable
        # such as a generator is always truthy, even when it yields nothing,
        # and the client must not be asked to delete zero keys.
        if not safe_keys:
            return
        self._cache.delete_many(safe_keys)

    def clear(self):
        """Clear the cache via the client; return the client's result."""
        return self._cache.clear()
|