diff options
author | Andy McCurdy <andy@andymccurdy.com> | 2014-05-12 13:25:40 -0700 |
---|---|---|
committer | Andy McCurdy <andy@andymccurdy.com> | 2014-05-12 13:25:40 -0700 |
commit | 6a55935205666d9461ab26205da7f2d5f18a78d0 (patch) | |
tree | 05fe195cca0806bf61ba27b1204988fed7232c1f | |
parent | d48eea2450ba76d13df1ac4204f19be3fe388166 (diff) | |
download | redis-py-6a55935205666d9461ab26205da7f2d5f18a78d0.tar.gz |
string literals no longer get encoded before being sent to Redis
previously all pieces of a command, including the command name and literal
options to it (such as "WITHSCORES" on ZSET commands) would get encoded.
this works fine on utf-8, but other encodings like utf-16 break.
a new Token class has been introduced that command names and literal options
get wrapped in. the encoder falls back to the latin-1 encoding for these
literals as they are all ascii.
fixes #430
-rw-r--r-- | redis/client.py | 291 | ||||
-rw-r--r-- | redis/connection.py | 53 | ||||
-rw-r--r-- | tests/test_encoding.py | 9 |
3 files changed, 193 insertions, 160 deletions
diff --git a/redis/client.py b/redis/client.py index 756eee6..59bdf28 100644 --- a/redis/client.py +++ b/redis/client.py @@ -7,7 +7,7 @@ import threading import time as mod_time from redis._compat import (b, basestring, bytes, imap, iteritems, iterkeys, itervalues, izip, long, nativestr, unicode) -from redis.connection import (ConnectionPool, UnixDomainSocketConnection) +from redis.connection import ConnectionPool, UnixDomainSocketConnection, Token from redis.exceptions import ( ConnectionError, DataError, @@ -155,25 +155,24 @@ def parse_sentinel_state(item): return result -def parse_sentinel(response, **options): - "Parse the result of Redis's SENTINEL command" - parse = options.get('parse') - if parse == 'SENTINEL_INFO': - return [parse_sentinel_state(imap(nativestr, item)) - for item in response] - elif parse == 'SENTINEL_INFO_MASTERS': - result = {} - for item in response: - state = parse_sentinel_state(imap(nativestr, item)) - result[state['name']] = state - return result - elif parse == 'SENTINEL_INFO_MASTER': - return parse_sentinel_state(imap(nativestr, response)) - elif parse == 'SENTINEL_ADDR_PORT': - if response is None: - return - return response[0], int(response[1]) - return response +def parse_sentinel_master(response): + return parse_sentinel_state(imap(nativestr, response)) + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(imap(nativestr, item)) + result[state['name']] = state + return result + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(imap(nativestr, item)) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None def pairs_to_dict(response): @@ -232,35 +231,20 @@ def float_or_none(response): return float(response) -def parse_client(response, **options): - parse = options['parse'] - if parse == 'LIST': - clients = [] - for c in nativestr(response).splitlines(): - 
clients.append(dict([pair.split('=') for pair in c.split(' ')])) - return clients - elif parse == 'KILL': - return bool(response) - elif parse == 'GETNAME': - return response and nativestr(response) - elif parse == 'SETNAME': - return nativestr(response) == 'OK' - - -def parse_config(response, **options): - if options['parse'] == 'GET': - response = [nativestr(i) if i is not None else None for i in response] - return response and pairs_to_dict(response) or {} +def bool_ok(response): return nativestr(response) == 'OK' -def parse_script(response, **options): - parse = options['parse'] - if parse in ('FLUSH', 'KILL'): - return response == 'OK' - if parse == 'EXISTS': - return list(imap(bool, response)) - return response +def parse_client_list(response, **options): + clients = [] + for c in nativestr(response).splitlines(): + clients.append(dict([pair.split('=') for pair in c.split(' ')])) + return clients + + +def parse_config_get(response, **options): + response = [nativestr(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} def parse_scan(response, **options): @@ -280,19 +264,13 @@ def parse_zscan(response, **options): return long(cursor), list(izip(it, imap(score_cast_func, it))) -def parse_slowlog(response, **options): - parse = options['parse'] - if parse == 'LEN': - return int(response) - elif parse == 'RESET': - return nativestr(response) == 'OK' - elif parse == 'GET': - return [{ - 'id': item[0], - 'start_time': int(item[1]), - 'duration': int(item[2]), - 'command': b(' ').join(item[3]) - } for item in response] +def parse_slowlog_get(response, **options): + return [{ + 'id': item[0], + 'start_time': int(item[1]), + 'duration': int(item[2]), + 'command': b(' ').join(item[3]) + } for item in response] class StrictRedis(object): @@ -329,7 +307,7 @@ class StrictRedis(object): string_keys_to_dict( 'FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE RENAME ' 'SAVE SELECT SHUTDOWN SLAVEOF WATCH UNWATCH', - lambda r: 
nativestr(r) == 'OK' + bool_ok ), string_keys_to_dict('BLPOP BRPOP', lambda r: r and tuple(r) or None), string_keys_to_dict( @@ -343,9 +321,14 @@ class StrictRedis(object): string_keys_to_dict('ZRANK ZREVRANK', int_or_none), string_keys_to_dict('BGREWRITEAOF BGSAVE', lambda r: True), { - 'CLIENT': parse_client, - 'CONFIG': parse_config, - 'DEBUG': parse_debug_object, + 'CLIENT GETNAME': lambda r: r and nativestr(r), + 'CLIENT KILL': bool_ok, + 'CLIENT LIST': parse_client_list, + 'CLIENT SETNAME': bool_ok, + 'CONFIG GET': parse_config_get, + 'CONFIG RESETSTAT': bool_ok, + 'CONFIG SET': bool_ok, + 'DEBUG OBJECT': parse_debug_object, 'HGETALL': lambda r: r and pairs_to_dict(r) or {}, 'HSCAN': parse_hscan, 'INFO': parse_info, @@ -353,13 +336,25 @@ class StrictRedis(object): 'OBJECT': parse_object, 'PING': lambda r: nativestr(r) == 'PONG', 'RANDOMKEY': lambda r: r and r or None, - 'SCRIPT': parse_script, - 'SET': lambda r: r and nativestr(r) == 'OK', - 'TIME': lambda x: (int(x[0]), int(x[1])), - 'SENTINEL': parse_sentinel, 'SCAN': parse_scan, - 'SLOWLOG': parse_slowlog, + 'SCRIPT EXISTS': lambda r: list(imap(bool, r)), + 'SCRIPT FLUSH': bool_ok, + 'SCRIPT KILL': bool_ok, + 'SCRIPT LOAD': nativestr, + 'SENTINEL GET-MASTER-ADDR-BY-NAME': parse_sentinel_get_master, + 'SENTINEL MASTER': parse_sentinel_master, + 'SENTINEL MASTERS': parse_sentinel_masters, + 'SENTINEL MONITOR': bool_ok, + 'SENTINEL REMOVE': bool_ok, + 'SENTINEL SENTINELS': parse_sentinel_slaves_and_sentinels, + 'SENTINEL SET': bool_ok, + 'SENTINEL SLAVES': parse_sentinel_slaves_and_sentinels, + 'SET': lambda r: r and nativestr(r) == 'OK', + 'SLOWLOG GET': parse_slowlog_get, + 'SLOWLOG LEN': int, + 'SLOWLOG RESET': bool_ok, 'SSCAN': parse_scan, + 'TIME': lambda x: (int(x[0]), int(x[1])), 'ZSCAN': parse_zscan } ) @@ -519,43 +514,43 @@ class StrictRedis(object): def client_kill(self, address): "Disconnects the client at ``address`` (ip:port)" - return self.execute_command('CLIENT', 'KILL', address, parse='KILL') 
+ return self.execute_command('CLIENT KILL', address) def client_list(self): "Returns a list of currently connected clients" - return self.execute_command('CLIENT', 'LIST', parse='LIST') + return self.execute_command('CLIENT LIST') def client_getname(self): "Returns the current connection name" - return self.execute_command('CLIENT', 'GETNAME', parse='GETNAME') + return self.execute_command('CLIENT GETNAME') def client_setname(self, name): "Sets the current connection name" - return self.execute_command('CLIENT', 'SETNAME', name, parse='SETNAME') + return self.execute_command('CLIENT SETNAME', name) def config_get(self, pattern="*"): "Return a dictionary of configuration based on the ``pattern``" - return self.execute_command('CONFIG', 'GET', pattern, parse='GET') + return self.execute_command('CONFIG GET', pattern) def config_set(self, name, value): "Set config item ``name`` with ``value``" - return self.execute_command('CONFIG', 'SET', name, value, parse='SET') + return self.execute_command('CONFIG SET', name, value) def config_resetstat(self): "Reset runtime statistics" - return self.execute_command('CONFIG', 'RESETSTAT', parse='RESETSTAT') + return self.execute_command('CONFIG RESETSTAT') def config_rewrite(self): "Rewrite config file with the minimal change to reflect running config" - return self.execute_command('CONFIG', 'REWRITE') + return self.execute_command('CONFIG REWRITE') def dbsize(self): "Returns the number of keys in the current database" return self.execute_command('DBSIZE') def debug_object(self, key): - "Returns version specific metainformation about a give key" - return self.execute_command('DEBUG', 'OBJECT', key) + "Returns version specific meta information about a given key" + return self.execute_command('DEBUG OBJECT', key) def echo(self, value): "Echo the string back from the server" @@ -607,50 +602,42 @@ class StrictRedis(object): return self.execute_command('SAVE') def sentinel(self, *args): - "Redis Sentinel's SENTINEL command" - if 
args[0] in ['masters', 'slaves', 'sentinels']: - parse = 'SENTINEL_INFO' - else: - parse = 'SENTINEL' - return self.execute_command('SENTINEL', *args, **{'parse': parse}) + "Redis Sentinel's SENTINEL command." + warnings.warn( + DeprecationWarning('Use the individual sentinel_* methods')) + + def sentinel_get_master_addr_by_name(self, service_name): + "Returns a (host, port) pair for the given ``service_name``" + return self.execute_command('SENTINEL GET-MASTER-ADDR-BY-NAME', + service_name) def sentinel_master(self, service_name): "Returns a dictionary containing the specified masters state." - return self.execute_command('SENTINEL', 'master', service_name, - parse='SENTINEL_INFO_MASTER') + return self.execute_command('SENTINEL MASTER', service_name) def sentinel_masters(self): "Returns a list of dictionaries containing each master's state." - return self.execute_command('SENTINEL', 'masters', - parse='SENTINEL_INFO_MASTERS') - - def sentinel_slaves(self, service_name): - "Returns a list of slaves for ``service_name``" - return self.execute_command('SENTINEL', 'slaves', service_name, - parse='SENTINEL_INFO') - - def sentinel_sentinels(self, service_name): - "Returns a list of sentinels for ``service_name``" - return self.execute_command('SENTINEL', 'sentinels', service_name, - parse='SENTINEL_INFO') - - def sentinel_get_master_addr_by_name(self, service_name): - "Returns a (host, port) pair for the given ``service_name``" - return self.execute_command('SENTINEL', 'get-master-addr-by-name', - service_name, parse='SENTINEL_ADDR_PORT') + return self.execute_command('SENTINEL MASTERS') def sentinel_monitor(self, name, ip, port, quorum): "Add a new master to Sentinel to be monitored" - return self.execute_command('SENTINEL', 'MONITOR', name, ip, port, - quorum) + return self.execute_command('SENTINEL MONITOR', name, ip, port, quorum) def sentinel_remove(self, name): "Remove a master from Sentinel's monitoring" - return self.execute_command('SENTINEL', 'REMOVE', name) + 
return self.execute_command('SENTINEL REMOVE', name) + + def sentinel_sentinels(self, service_name): + "Returns a list of sentinels for ``service_name``" + return self.execute_command('SENTINEL SENTINELS', service_name) def sentinel_set(self, name, option, value): "Set Sentinel monitoring parameters for a given master" - return self.execute_command('SENTINEL', 'SET', name, option, value) + return self.execute_command('SENTINEL SET', name, option, value) + + def sentinel_slaves(self, service_name): + "Returns a list of slaves for ``service_name``" + return self.execute_command('SENTINEL SLAVES', service_name) def shutdown(self): "Shutdown the server" @@ -668,26 +655,26 @@ class StrictRedis(object): instance is promoted to a master instead. """ if host is None and port is None: - return self.execute_command("SLAVEOF", "NO", "ONE") - return self.execute_command("SLAVEOF", host, port) + return self.execute_command('SLAVEOF', Token('NO'), Token('ONE')) + return self.execute_command('SLAVEOF', host, port) def slowlog_get(self, num=None): """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. 
""" - args = ['SLOWLOG', 'GET'] + args = ['SLOWLOG GET'] if num is not None: args.append(num) - return self.execute_command(*args, parse='GET') + return self.execute_command(*args) def slowlog_len(self): "Get the number of items in the slowlog" - return self.execute_command('SLOWLOG', 'LEN', parse='LEN') + return self.execute_command('SLOWLOG LEN') def slowlog_reset(self): "Remove all items in the slowlog" - return self.execute_command('SLOWLOG', 'RESET', parse='RESET') + return self.execute_command('SLOWLOG RESET') def time(self): """ @@ -1214,10 +1201,10 @@ class StrictRedis(object): pieces = [name] if by is not None: - pieces.append('BY') + pieces.append(Token('BY')) pieces.append(by) if start is not None and num is not None: - pieces.append('LIMIT') + pieces.append(Token('LIMIT')) pieces.append(start) pieces.append(num) if get is not None: @@ -1226,18 +1213,18 @@ class StrictRedis(object): # values. We can't just iterate blindly because strings are # iterable. if isinstance(get, basestring): - pieces.append('GET') + pieces.append(Token('GET')) pieces.append(get) else: for g in get: - pieces.append('GET') + pieces.append(Token('GET')) pieces.append(g) if desc: - pieces.append('DESC') + pieces.append(Token('DESC')) if alpha: - pieces.append('ALPHA') + pieces.append(Token('ALPHA')) if store is not None: - pieces.append('STORE') + pieces.append(Token('STORE')) pieces.append(store) if groups: @@ -1261,9 +1248,9 @@ class StrictRedis(object): """ pieces = [cursor] if match is not None: - pieces.extend(['MATCH', match]) + pieces.extend([Token('MATCH'), match]) if count is not None: - pieces.extend(['COUNT', count]) + pieces.extend([Token('COUNT'), count]) return self.execute_command('SCAN', *pieces) def scan_iter(self, match=None, count=None): @@ -1292,9 +1279,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend(['MATCH', match]) + pieces.extend([Token('MATCH'), match]) if count is not None: - pieces.extend(['COUNT', count]) 
+ pieces.extend([Token('COUNT'), count]) return self.execute_command('SSCAN', *pieces) def sscan_iter(self, name, match=None, count=None): @@ -1324,9 +1311,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend(['MATCH', match]) + pieces.extend([Token('MATCH'), match]) if count is not None: - pieces.extend(['COUNT', count]) + pieces.extend([Token('COUNT'), count]) return self.execute_command('HSCAN', *pieces) def hscan_iter(self, name, match=None, count=None): @@ -1359,9 +1346,9 @@ class StrictRedis(object): """ pieces = [name, cursor] if match is not None: - pieces.extend(['MATCH', match]) + pieces.extend([Token('MATCH'), match]) if count is not None: - pieces.extend(['COUNT', count]) + pieces.extend([Token('COUNT'), count]) options = {'score_cast_func': score_cast_func} return self.execute_command('ZSCAN', *pieces, **options) @@ -1537,9 +1524,11 @@ class StrictRedis(object): score_cast_func) pieces = ['ZRANGE', name, start, end] if withscores: - pieces.append('withscores') + pieces.append(Token('WITHSCORES')) options = { - 'withscores': withscores, 'score_cast_func': score_cast_func} + 'withscores': withscores, + 'score_cast_func': score_cast_func + } return self.execute_command(*pieces, **options) def zrangebylex(self, name, min, max, start=None, num=None): @@ -1555,7 +1544,7 @@ class StrictRedis(object): raise RedisError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYLEX', name, min, max] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend([Token('LIMIT'), start, num]) return self.execute_command(*pieces) def zrangebyscore(self, name, min, max, start=None, num=None, @@ -1577,11 +1566,13 @@ class StrictRedis(object): raise RedisError("``start`` and ``num`` must both be specified") pieces = ['ZRANGEBYSCORE', name, min, max] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend([Token('LIMIT'), start, num]) if 
withscores: - pieces.append('withscores') + pieces.append(Token('WITHSCORES')) options = { - 'withscores': withscores, 'score_cast_func': score_cast_func} + 'withscores': withscores, + 'score_cast_func': score_cast_func + } return self.execute_command(*pieces, **options) def zrank(self, name, value): @@ -1635,9 +1626,11 @@ class StrictRedis(object): """ pieces = ['ZREVRANGE', name, start, end] if withscores: - pieces.append('withscores') + pieces.append(Token('WITHSCORES')) options = { - 'withscores': withscores, 'score_cast_func': score_cast_func} + 'withscores': withscores, + 'score_cast_func': score_cast_func + } return self.execute_command(*pieces, **options) def zrevrangebyscore(self, name, max, min, start=None, num=None, @@ -1659,11 +1652,13 @@ class StrictRedis(object): raise RedisError("``start`` and ``num`` must both be specified") pieces = ['ZREVRANGEBYSCORE', name, max, min] if start is not None and num is not None: - pieces.extend(['LIMIT', start, num]) + pieces.extend([Token('LIMIT'), start, num]) if withscores: - pieces.append('withscores') + pieces.append(Token('WITHSCORES')) options = { - 'withscores': withscores, 'score_cast_func': score_cast_func} + 'withscores': withscores, + 'score_cast_func': score_cast_func + } return self.execute_command(*pieces, **options) def zrevrank(self, name, value): @@ -1693,10 +1688,10 @@ class StrictRedis(object): weights = None pieces.extend(keys) if weights: - pieces.append('WEIGHTS') + pieces.append(Token('WEIGHTS')) pieces.extend(weights) if aggregate: - pieces.append('AGGREGATE') + pieces.append(Token('AGGREGATE')) pieces.append(aggregate) return self.execute_command(*pieces) @@ -1763,7 +1758,7 @@ class StrictRedis(object): Set ``key`` to ``value`` within hash ``name`` if ``key`` does not exist. Returns 1 if HSETNX created a field, otherwise 0. 
""" - return self.execute_command("HSETNX", name, key, value) + return self.execute_command('HSETNX', name, key, value) def hmset(self, name, mapping): """ @@ -1822,23 +1817,19 @@ class StrictRedis(object): each script as ``args``. Returns a list of boolean values indicating if if each already script exists in the cache. """ - options = {'parse': 'EXISTS'} - return self.execute_command('SCRIPT', 'EXISTS', *args, **options) + return self.execute_command('SCRIPT EXISTS', *args) def script_flush(self): "Flush all scripts from the script cache" - options = {'parse': 'FLUSH'} - return self.execute_command('SCRIPT', 'FLUSH', **options) + return self.execute_command('SCRIPT FLUSH') def script_kill(self): "Kill the currently executing Lua script" - options = {'parse': 'KILL'} - return self.execute_command('SCRIPT', 'KILL', **options) + return self.execute_command('SCRIPT KILL') def script_load(self, script): "Load a Lua ``script`` into the script cache. Returns the SHA." - options = {'parse': 'LOAD'} - return self.execute_command('SCRIPT', 'LOAD', script, **options) + return self.execute_command('SCRIPT LOAD', script) def register_script(self, script): """ diff --git a/redis/connection.py b/redis/connection.py index c5c5e7a..2272133 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -41,6 +41,24 @@ SYM_CRLF = b('\r\n') SYM_EMPTY = b('') +class Token(object): + """ + Literal strings in Redis commands, such as the command names and any + hard-coded arguments are wrapped in this class so we know not to apply + and encoding rules on them. 
+ """ + def __init__(self, value): + if isinstance(value, Token): + value = value.value + self.value = value + + def __repr__(self): + return self.value + + def __str__(self): + return self.value + + class BaseParser(object): EXCEPTION_CLASSES = { 'ERR': ResponseError, @@ -487,32 +505,47 @@ class Connection(object): def encode(self, value): "Return a bytestring representation of the value" + if isinstance(value, Token): + return b(value.value) if isinstance(value, bytes): return value - if isinstance(value, float): - value = repr(value) - if not isinstance(value, basestring): + if isinstance(value, (int, long, float)): + value = b(repr(value)) + elif not isinstance(value, basestring): value = str(value) if isinstance(value, unicode): value = value.encode(self.encoding, self.encoding_errors) return value def pack_command(self, *args): - "Pack a series of arguments into a value Redis command" + "Pack a series of arguments into the Redis protocol" output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. All of these arguements get wrapped in the Token class + # to prevent them from being encoded. 
+ command = args[0] + if ' ' in command: + args = tuple([Token(s) for s in command.split(' ')]) + args[1:] + else: + args = (Token(command),) + args[1:] + buff = SYM_EMPTY.join( (SYM_STAR, b(str(len(args))), SYM_CRLF)) - for k in imap(self.encode, args): - if len(buff) > 6000 or len(k) > 6000: + for arg in imap(self.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values + if len(buff) > 6000 or len(arg) > 6000: buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, b(str(len(k))), SYM_CRLF)) + (buff, SYM_DOLLAR, b(str(len(arg))), SYM_CRLF)) output.append(buff) - output.append(k) + output.append(arg) buff = SYM_CRLF else: - buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(k))), - SYM_CRLF, k, SYM_CRLF)) + buff = SYM_EMPTY.join((buff, SYM_DOLLAR, b(str(len(arg))), + SYM_CRLF, arg, SYM_CRLF)) output.append(buff) return output diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 2aa95d3..b1df0a5 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -22,3 +22,12 @@ class TestEncoding(object): result = [unicode_string, unicode_string, unicode_string] r.rpush('a', *result) assert r.lrange('a', 0, -1) == result + + +class TestCommandsAndTokensArentEncoded(object): + @pytest.fixture() + def r(self, request): + return _redis_client(request=request, charset='utf-16') + + def test_basic_command(self, r): + r.set('hello', 'world') |