summaryrefslogtreecommitdiff
path: root/test/engine/test_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r--test/engine/test_pool.py664
1 files changed, 664 insertions, 0 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
new file mode 100644
index 000000000..43a0fc38b
--- /dev/null
+++ b/test/engine/test_pool.py
@@ -0,0 +1,664 @@
+import threading, time, gc
+from sqlalchemy import pool, interfaces
+import sqlalchemy as tsa
+from sqlalchemy.test import TestBase
+
+
+mcid = 1
+class MockDBAPI(object):
+ def __init__(self):
+ self.throw_error = False
+ def connect(self, *args, **kwargs):
+ if self.throw_error:
+ raise Exception("couldnt connect !")
+ delay = kwargs.pop('delay', 0)
+ if delay:
+ time.sleep(delay)
+ return MockConnection()
+class MockConnection(object):
+ def __init__(self):
+ global mcid
+ self.id = mcid
+ self.closed = False
+ mcid += 1
+ def close(self):
+ self.closed = True
+ def rollback(self):
+ pass
+ def cursor(self):
+ return MockCursor()
+class MockCursor(object):
+ def execute(self, *args, **kw):
+ pass
+ def close(self):
+ pass
+mock_dbapi = MockDBAPI()
+
+
+class PoolTestBase(TestBase):
+ def setup(self):
+ pool.clear_managers()
+
+ @classmethod
+ def teardown_class(cls):
+ pool.clear_managers()
+
+class PoolTest(PoolTestBase):
+ def testmanager(self):
+ manager = pool.manage(mock_dbapi, use_threadlocal=True)
+
+ connection = manager.connect('foo.db')
+ connection2 = manager.connect('foo.db')
+ connection3 = manager.connect('bar.db')
+
+ print "connection " + repr(connection)
+ self.assert_(connection.cursor() is not None)
+ self.assert_(connection is connection2)
+ self.assert_(connection2 is not connection3)
+
+ def testbadargs(self):
+ manager = pool.manage(mock_dbapi)
+
+ try:
+ connection = manager.connect(None)
+ except:
+ pass
+
+ def testnonthreadlocalmanager(self):
+ manager = pool.manage(mock_dbapi, use_threadlocal = False)
+
+ connection = manager.connect('foo.db')
+ connection2 = manager.connect('foo.db')
+
+ print "connection " + repr(connection)
+
+ self.assert_(connection.cursor() is not None)
+ self.assert_(connection is not connection2)
+
+
+ def testthreadlocal_del(self):
+ self._do_testthreadlocal(useclose=False)
+
+ def testthreadlocal_close(self):
+ self._do_testthreadlocal(useclose=True)
+
+ def _do_testthreadlocal(self, useclose=False):
+ for p in (
+ pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True),
+ pool.SingletonThreadPool(creator = mock_dbapi.connect, use_threadlocal = True)
+ ):
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ c3 = p.unique_connection()
+ self.assert_(c3 is not c1)
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ self.assert_(c3 is not c1)
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+
+ if useclose:
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ c3.close()
+ c2.close()
+ self.assert_(c1.connection is not None)
+ c1.close()
+
+ c1 = c2 = c3 = None
+
+ # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
+ if isinstance(p, pool.QueuePool):
+ self.assert_(p.checkedout() == 0)
+ c1 = p.connect()
+ c2 = p.connect()
+ if useclose:
+ c2.close()
+ c1.close()
+ else:
+ c2 = None
+ c1 = None
+ self.assert_(p.checkedout() == 0)
+
+ def test_properties(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+ pool_size=1, max_overflow=0, use_threadlocal=False)
+
+ c = p.connect()
+ self.assert_(not c.info)
+ self.assert_(c.info is c._connection_record.info)
+
+ c.info['foo'] = 'bar'
+ c.close()
+ del c
+
+ c = p.connect()
+ self.assert_('foo' in c.info)
+
+ c.invalidate()
+ c = p.connect()
+ self.assert_('foo' not in c.info)
+
+ c.info['foo2'] = 'bar2'
+ c.detach()
+ self.assert_('foo2' in c.info)
+
+ c2 = p.connect()
+ self.assert_(c.connection is not c2.connection)
+ self.assert_(not c2.info)
+ self.assert_('foo2' in c.info)
+
+ def test_listeners(self):
+ dbapi = MockDBAPI()
+
+ class InstrumentingListener(object):
+ def __init__(self):
+ if hasattr(self, 'connect'):
+ self.connect = self.inst_connect
+ if hasattr(self, 'checkout'):
+ self.checkout = self.inst_checkout
+ if hasattr(self, 'checkin'):
+ self.checkin = self.inst_checkin
+ self.clear()
+ def clear(self):
+ self.connected = []
+ self.checked_out = []
+ self.checked_in = []
+ def assert_total(innerself, conn, cout, cin):
+ self.assert_(len(innerself.connected) == conn)
+ self.assert_(len(innerself.checked_out) == cout)
+ self.assert_(len(innerself.checked_in) == cin)
+ def assert_in(innerself, item, in_conn, in_cout, in_cin):
+ self.assert_((item in innerself.connected) == in_conn)
+ self.assert_((item in innerself.checked_out) == in_cout)
+ self.assert_((item in innerself.checked_in) == in_cin)
+ def inst_connect(self, con, record):
+ print "connect(%s, %s)" % (con, record)
+ assert con is not None
+ assert record is not None
+ self.connected.append(con)
+ def inst_checkout(self, con, record, proxy):
+ print "checkout(%s, %s, %s)" % (con, record, proxy)
+ assert con is not None
+ assert record is not None
+ assert proxy is not None
+ self.checked_out.append(con)
+ def inst_checkin(self, con, record):
+ print "checkin(%s, %s)" % (con, record)
+ # con can be None if invalidated
+ assert record is not None
+ self.checked_in.append(con)
+
+ class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
+ pass
+ class ListenConnect(InstrumentingListener):
+ def connect(self, con, record):
+ pass
+ class ListenCheckOut(InstrumentingListener):
+ def checkout(self, con, record, proxy, num):
+ pass
+ class ListenCheckIn(InstrumentingListener):
+ def checkin(self, con, record):
+ pass
+
+ def _pool(**kw):
+ return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+ use_threadlocal=False, **kw)
+
+ def assert_listeners(p, total, conn, cout, cin):
+ for instance in (p, p.recreate()):
+ self.assert_(len(instance.listeners) == total)
+ self.assert_(len(instance._on_connect) == conn)
+ self.assert_(len(instance._on_checkout) == cout)
+ self.assert_(len(instance._on_checkin) == cin)
+
+ p = _pool()
+ assert_listeners(p, 0, 0, 0, 0)
+
+ p.add_listener(ListenAll())
+ assert_listeners(p, 1, 1, 1, 1)
+
+ p.add_listener(ListenConnect())
+ assert_listeners(p, 2, 2, 1, 1)
+
+ p.add_listener(ListenCheckOut())
+ assert_listeners(p, 3, 2, 2, 1)
+
+ p.add_listener(ListenCheckIn())
+ assert_listeners(p, 4, 2, 2, 2)
+ del p
+
+ print "----"
+ snoop = ListenAll()
+ p = _pool(listeners=[snoop])
+ assert_listeners(p, 1, 1, 1, 1)
+
+ c = p.connect()
+ snoop.assert_total(1, 1, 0)
+ cc = c.connection
+ snoop.assert_in(cc, True, True, False)
+ c.close()
+ snoop.assert_in(cc, True, True, True)
+ del c, cc
+
+ snoop.clear()
+
+ # this one depends on immediate gc
+ c = p.connect()
+ cc = c.connection
+ snoop.assert_in(cc, False, True, False)
+ snoop.assert_total(0, 1, 0)
+ del c, cc
+ snoop.assert_total(0, 1, 1)
+
+ p.dispose()
+ snoop.clear()
+
+ c = p.connect()
+ c.close()
+ c = p.connect()
+ snoop.assert_total(1, 2, 1)
+ c.close()
+ snoop.assert_total(1, 2, 2)
+
+ # invalidation
+ p.dispose()
+ snoop.clear()
+
+ c = p.connect()
+ snoop.assert_total(1, 1, 0)
+ c.invalidate()
+ snoop.assert_total(1, 1, 1)
+ c.close()
+ snoop.assert_total(1, 1, 1)
+ del c
+ snoop.assert_total(1, 1, 1)
+ c = p.connect()
+ snoop.assert_total(2, 2, 1)
+ c.close()
+ del c
+ snoop.assert_total(2, 2, 2)
+
+ # detached
+ p.dispose()
+ snoop.clear()
+
+ c = p.connect()
+ snoop.assert_total(1, 1, 0)
+ c.detach()
+ snoop.assert_total(1, 1, 0)
+ c.close()
+ del c
+ snoop.assert_total(1, 1, 0)
+ c = p.connect()
+ snoop.assert_total(2, 2, 0)
+ c.close()
+ del c
+ snoop.assert_total(2, 2, 1)
+
+ def test_listeners_callables(self):
+ dbapi = MockDBAPI()
+
+ counts = [0, 0, 0]
+ def connect(dbapi_con, con_record):
+ counts[0] += 1
+ def checkout(dbapi_con, con_record, con_proxy):
+ counts[1] += 1
+ def checkin(dbapi_con, con_record):
+ counts[2] += 1
+
+ i_all = dict(connect=connect, checkout=checkout, checkin=checkin)
+ i_connect = dict(connect=connect)
+ i_checkout = dict(checkout=checkout)
+ i_checkin = dict(checkin=checkin)
+
+ def _pool(**kw):
+ return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
+ use_threadlocal=False, **kw)
+
+ def assert_listeners(p, total, conn, cout, cin):
+ for instance in (p, p.recreate()):
+ self.assert_(len(instance.listeners) == total)
+ self.assert_(len(instance._on_connect) == conn)
+ self.assert_(len(instance._on_checkout) == cout)
+ self.assert_(len(instance._on_checkin) == cin)
+
+ p = _pool()
+ assert_listeners(p, 0, 0, 0, 0)
+
+ p.add_listener(i_all)
+ assert_listeners(p, 1, 1, 1, 1)
+
+ p.add_listener(i_connect)
+ assert_listeners(p, 2, 2, 1, 1)
+
+ p.add_listener(i_checkout)
+ assert_listeners(p, 3, 2, 2, 1)
+
+ p.add_listener(i_checkin)
+ assert_listeners(p, 4, 2, 2, 2)
+ del p
+
+ p = _pool(listeners=[i_all])
+ assert_listeners(p, 1, 1, 1, 1)
+
+ c = p.connect()
+ assert counts == [1, 1, 0]
+ c.close()
+ assert counts == [1, 1, 1]
+
+ c = p.connect()
+ assert counts == [1, 2, 1]
+ p.add_listener(i_checkin)
+ c.close()
+ assert counts == [1, 2, 3]
+
+class QueuePoolTest(PoolTestBase):
+
+ def testqueuepool_del(self):
+ self._do_testqueuepool(useclose=False)
+
+ def testqueuepool_close(self):
+ self._do_testqueuepool(useclose=True)
+
+ def _do_testqueuepool(self, useclose=False):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
+
+ def status(pool):
+ tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+ print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
+ return tup
+
+ c1 = p.connect()
+ self.assert_(status(p) == (3,0,-2,1))
+ c2 = p.connect()
+ self.assert_(status(p) == (3,0,-1,2))
+ c3 = p.connect()
+ self.assert_(status(p) == (3,0,0,3))
+ c4 = p.connect()
+ self.assert_(status(p) == (3,0,1,4))
+ c5 = p.connect()
+ self.assert_(status(p) == (3,0,2,5))
+ c6 = p.connect()
+ self.assert_(status(p) == (3,0,3,6))
+ if useclose:
+ c4.close()
+ c3.close()
+ c2.close()
+ else:
+ c4 = c3 = c2 = None
+ self.assert_(status(p) == (3,3,3,3))
+ if useclose:
+ c1.close()
+ c5.close()
+ c6.close()
+ else:
+ c1 = c5 = c6 = None
+ self.assert_(status(p) == (3,3,0,0))
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(status(p) == (3, 1, 0, 2), status(p))
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ self.assert_(status(p) == (3, 2, 0, 1))
+
+ def test_timeout(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ now = time.time()
+ try:
+ c4 = p.connect()
+ assert False
+ except tsa.exc.TimeoutError, e:
+ assert int(time.time() - now) == 2
+
+ def test_timeout_race(self):
+ # test a race condition where the initial connecting threads all race
+ # to queue.Empty, then block on the mutex. each thread consumes a
+ # connection as they go in. when the limit is reached, the remaining
+ # threads go in, and get TimeoutError; even though they never got to
+ # wait for the timeout on queue.get(). the fix involves checking the
+ # timeout again within the mutex, and if so, unlocking and throwing
+ # them back to the start of do_get()
+ p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
+ timeouts = []
+ def checkout():
+ for x in xrange(1):
+ now = time.time()
+ try:
+ c1 = p.connect()
+ except tsa.exc.TimeoutError, e:
+ timeouts.append(int(time.time()) - now)
+ continue
+ time.sleep(4)
+ c1.close()
+
+ threads = []
+ for i in xrange(10):
+ th = threading.Thread(target=checkout)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ print timeouts
+ assert len(timeouts) > 0
+ for t in timeouts:
+ assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
+
+ def _test_overflow(self, thread_count, max_overflow):
+ def creator():
+ time.sleep(.05)
+ return mock_dbapi.connect()
+
+ p = pool.QueuePool(creator=creator,
+ pool_size=3, timeout=2,
+ max_overflow=max_overflow)
+ peaks = []
+ def whammy():
+ for i in range(10):
+ try:
+ con = p.connect()
+ time.sleep(.005)
+ peaks.append(p.overflow())
+ con.close()
+ del con
+ except tsa.exc.TimeoutError:
+ pass
+ threads = []
+ for i in xrange(thread_count):
+ th = threading.Thread(target=whammy)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ self.assert_(max(peaks) <= max_overflow)
+
+ def test_no_overflow(self):
+ self._test_overflow(40, 0)
+
+ def test_max_overflow(self):
+ self._test_overflow(40, 5)
+
+ def test_mixed_close(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = None
+ assert p.checkedout() == 1
+ c1 = None
+ assert p.checkedout() == 0
+
+ def test_weakref_kaboom(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ c1.close()
+ c2 = None
+ del c1
+ del c2
+ gc.collect()
+ assert p.checkedout() == 0
+ c3 = p.connect()
+ assert c3 is not None
+
+ def test_trick_the_counter(self):
+ """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
+ with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
+ ambiguous counter. i.e. its not true reference counting."""
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = p.connect()
+ c2.close()
+ self.assert_(p.checkedout() != 0)
+
+ c2.close()
+ self.assert_(p.checkedout() == 0)
+
+ def test_recycle(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
+
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close()
+ c2 = p.connect()
+ assert id(c2.connection) == c_id
+ c2.close()
+ time.sleep(4)
+ c3= p.connect()
+ assert id(c3.connection) != c_id
+
+ def test_invalidate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_recreate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ p2 = p.recreate()
+ assert p2.size() == 1
+ assert p2._use_threadlocal is False
+ assert p2._max_overflow == 0
+
+ def test_reconnect(self):
+ """tests reconnect operations at the pool level. SA's engine/dialect includes another
+ layer of reconnect support for 'database was lost' errors."""
+
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ dbapi.raise_error = True
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_detach(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+
+ c1 = p.connect()
+ c1.detach()
+ c_id = c1.connection.id
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+ dbapi.raise_error = True
+
+ c2.invalidate()
+ c2 = None
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+
+ con = c1.connection
+
+ assert not con.closed
+ c1.close()
+ assert con.closed
+
+ def test_threadfairy(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c1.close()
+ c2 = p.connect()
+ assert c2.connection is not None
+
+class SingletonThreadPoolTest(PoolTestBase):
+ def test_cleanup(self):
+ """test that the pool's connections are OK after cleanup() has been called."""
+
+ p = pool.SingletonThreadPool(creator = mock_dbapi.connect, pool_size=3)
+
+ def checkout():
+ for x in xrange(10):
+ c = p.connect()
+ assert c
+ c.cursor()
+ c.close()
+
+ time.sleep(.1)
+
+ threads = []
+ for i in xrange(10):
+ th = threading.Thread(target=checkout)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ assert len(p._all_conns) == 3
+
+class NullPoolTest(PoolTestBase):
+ def test_reconnect(self):
+ dbapi = MockDBAPI()
+ p = pool.NullPool(creator = lambda: dbapi.connect('foo.db'))
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ dbapi.raise_error = True
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+
+