summaryrefslogtreecommitdiff
path: root/tests/lock.py
blob: 9a3254204c5cb577efa36225983e0deaba016526 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import with_statement
import redis
import time
import unittest
from redis.client import Lock, LockError

class LockTestCase(unittest.TestCase):
    """Integration tests for the lock objects returned by ``Redis.lock()``.

    The tests talk to a live Redis server on localhost:6379 and use
    database 9, which is flushed before and after every test.
    """

    def setUp(self):
        self.client = redis.Redis(host='localhost', port=6379, db=9)
        self.client.flushdb()

    def tearDown(self):
        # Flushing also removes any lock key left behind by a failed test.
        self.client.flushdb()

    def test_lock(self):
        """A lock without a timeout stores LOCK_FOREVER and is removed on release."""
        lock = self.client.lock('foo')
        self.assertTrue(lock.acquire())
        self.assertEqual(self.client['foo'], str(Lock.LOCK_FOREVER))
        lock.release()
        self.assertEqual(self.client.get('foo'), None)

    def test_competing_locks(self):
        """Two locks on the same name exclude each other until released."""
        lock1 = self.client.lock('foo')
        lock2 = self.client.lock('foo')
        self.assertTrue(lock1.acquire())
        self.assertFalse(lock2.acquire(blocking=False))
        lock1.release()
        self.assertTrue(lock2.acquire())
        self.assertFalse(lock1.acquire(blocking=False))
        lock2.release()

    def test_timeouts(self):
        """A lock with a timeout records its expiry and can be taken over once expired."""
        lock1 = self.client.lock('foo', timeout=1)
        lock2 = self.client.lock('foo')
        # The expected expiry is whole-second based (int(time.time()) + 1).
        # Sample the clock on both sides of acquire() so that crossing a
        # second boundary during the call cannot make the check flaky.
        before = int(time.time())
        self.assertTrue(lock1.acquire())
        after = int(time.time())
        expected = [float(second) + 1 for second in range(before, after + 1)]
        self.assertTrue(
            lock1.acquired_until in expected,
            '%r not in %r' % (lock1.acquired_until, expected))
        self.assertEqual(lock1.acquired_until, float(self.client['foo']))
        self.assertFalse(lock2.acquire(blocking=False))
        time.sleep(2)  # need to wait up to 2 seconds for lock to timeout
        self.assertTrue(lock2.acquire(blocking=False))
        lock2.release()

    def test_non_blocking(self):
        """A non-blocking acquire sets acquired_until; release clears it."""
        lock1 = self.client.lock('foo')
        self.assertTrue(lock1.acquire(blocking=False))
        self.assertTrue(lock1.acquired_until)
        lock1.release()
        self.assertEqual(lock1.acquired_until, None)

    def test_context_manager(self):
        """Using the lock in a ``with`` block acquires it and releases it on exit."""
        with self.client.lock('foo'):
            self.assertEqual(self.client['foo'], str(Lock.LOCK_FOREVER))
        self.assertEqual(self.client.get('foo'), None)

    def test_float_timeout(self):
        """A fractional timeout is accepted and still excludes a competing lock."""
        lock1 = self.client.lock('foo', timeout=1.5)
        lock2 = self.client.lock('foo', timeout=1.5)
        self.assertTrue(lock1.acquire())
        self.assertFalse(lock2.acquire(blocking=False))
        lock1.release()

    def test_high_sleep_raises_error(self):
        "If sleep is higher than timeout, it should raise an error"
        self.assertRaises(
            LockError,
            self.client.lock, 'foo', timeout=1, sleep=2
            )