From 9c8be70e12c501dd0682232e9a7b686ac5e70ec3 Mon Sep 17 00:00:00 2001 From: Nick Gaya Date: Thu, 5 Mar 2020 01:59:57 -0800 Subject: Clear pipeline watch state after exec --- .gitignore | 3 ++- CHANGES | 3 +++ redis/client.py | 3 +++ tests/conftest.py | 21 +++++++++++++++++++++ tests/test_monitor.py | 23 ++--------------------- tests/test_pipeline.py | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 66 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index ab39968..7de7594 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ vagrant/.vagrant .python-version .cache .eggs -.idea \ No newline at end of file +.idea +.coverage diff --git a/CHANGES b/CHANGES index 1eff175..32e8b95 100644 --- a/CHANGES +++ b/CHANGES @@ -18,6 +18,9 @@ deprecated now. Thanks to @laixintao #1271 * Don't manually DISCARD when encountering an ExecAbortError. Thanks @nickgaya, #1300/#1301 + * Reset the watched state of pipelines after calling exec. This saves + a roundtrip to the server by not having to call UNWATCH within + Pipeline.reset(). Thanks @nickgaya, #1299/#1302 * 3.4.1 * Move the username argument in the Redis and Connection classes to the end of the argument list. This helps those poor souls that specify all diff --git a/redis/client.py b/redis/client.py index 19707b2..9f75465 100755 --- a/redis/client.py +++ b/redis/client.py @@ -3902,6 +3902,9 @@ class Pipeline(Redis): raise errors[0][1] raise sys.exc_info()[1] + # EXEC clears any watched keys + self.watching = False + if response is None: raise WatchError("Watched variable changed.") diff --git a/tests/conftest.py b/tests/conftest.py index 0007b84..b0827b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import random + import pytest import redis from mock import Mock @@ -146,3 +148,22 @@ def mock_cluster_resp_slaves(request, **kwargs): "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " "1447836789290 3 connected']") return _gen_cluster_mock_resp(r, response) + + +def wait_for_command(client, monitor, command): + # issue a command with a key name that's local to this process. + # if we find a command with our key before the command we're waiting + # for, something went wrong + redis_version = REDIS_INFO["version"] + if StrictVersion(redis_version) >= StrictVersion('5.0.0'): + id_str = str(client.client_id()) + else: + id_str = '%08x' % random.randrange(2**32) + key = '__REDIS-PY-%s__' % id_str + client.get(key) + while True: + monitor_response = monitor.next_command() + if command in monitor_response['command']: + return monitor_response + if key in monitor_response['command']: + return None diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 7ef8ecd..0e39ec0 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,31 +1,15 @@ from __future__ import unicode_literals from redis._compat import unicode -from .conftest import skip_if_server_version_lt +from .conftest import wait_for_command -def wait_for_command(client, monitor, command): - # issue a command with a key name that's local to this process. - # if we find a command with our key before the command we're waiting - # for, something went wrong - key = '__REDIS-PY-%s__' % str(client.client_id()) - client.get(key) - while True: - monitor_response = monitor.next_command() - if command in monitor_response['command']: - return monitor_response - if key in monitor_response['command']: - return None - - -class TestPipeline(object): - @skip_if_server_version_lt('5.0.0') +class TestMonitor(object): def test_wait_command_not_found(self, r): "Make sure the wait_for_command func works when command is not found" with r.monitor() as m: response = wait_for_command(r, m, 'nothing') assert response is None - @skip_if_server_version_lt('5.0.0') def test_response_values(self, r): with r.monitor() as m: r.ping() @@ -37,14 +21,12 @@ class TestPipeline(object): assert isinstance(response['client_port'], unicode) assert response['command'] == 'PING' - @skip_if_server_version_lt('5.0.0') def test_command_with_quoted_key(self, r): with r.monitor() as m: r.get('foo"bar') response = wait_for_command(r, m, 'GET foo"bar') assert response['command'] == 'GET foo"bar' - @skip_if_server_version_lt('5.0.0') def test_command_with_binary_data(self, r): with r.monitor() as m: byte_string = b'foo\x92' @@ -52,7 +34,6 @@ class TestPipeline(object): response = wait_for_command(r, m, 'GET foo\\x92') assert response['command'] == 'GET foo\\x92' - @skip_if_server_version_lt('5.0.0') def test_lua_script(self, r): with r.monitor() as m: script = 'return redis.call("GET", "foo")' diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 088071b..4f22153 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -3,6 +3,7 @@ import pytest import redis from redis._compat import unichr, unicode +from .conftest import wait_for_command class TestPipeline(object): @@ -243,6 +244,40 @@ class TestPipeline(object): pipe.get('a') assert pipe.execute() == [b'1'] + def test_watch_exec_no_unwatch(self, r): + r['a'] = 1 + r['b'] = 2 + + with r.monitor() as m: + with r.pipeline() as pipe: + pipe.watch('a', 'b') + assert pipe.watching + a_value = pipe.get('a') + b_value = pipe.get('b') + assert a_value == b'1' + assert b_value == b'2' + pipe.multi() + pipe.set('c', 3) + assert pipe.execute() == [True] + assert not pipe.watching + + unwatch_command = wait_for_command(r, m, 'UNWATCH') + assert unwatch_command is None, "should not send UNWATCH" + + def test_watch_reset_unwatch(self, r): + r['a'] = 1 + + with r.monitor() as m: + with r.pipeline() as pipe: + pipe.watch('a') + assert pipe.watching + pipe.reset() + assert not pipe.watching + + unwatch_command = wait_for_command(r, m, 'UNWATCH') + assert unwatch_command is not None + assert unwatch_command['command'] == 'UNWATCH' + def test_transaction_callable(self, r): r['a'] = 1 r['b'] = 2 -- cgit v1.2.1