summaryrefslogtreecommitdiff
path: root/test/engine/test_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r--test/engine/test_pool.py248
1 files changed, 234 insertions, 14 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index 05c0487f8..2e4c2dc48 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -10,6 +10,8 @@ from sqlalchemy.testing import fixtures
from sqlalchemy.testing.mock import Mock, call
+join_timeout = 10
+
def MockDBAPI():
def cursor():
while True:
@@ -306,6 +308,13 @@ class PoolEventsTest(PoolTestBase):
return p, canary
+ def _invalidate_event_fixture(self):
+ p = self._queuepool_fixture()
+ canary = Mock()
+ event.listen(p, 'invalidate', canary)
+
+ return p, canary
+
def test_first_connect_event(self):
p, canary = self._first_connect_event_fixture()
@@ -409,6 +418,31 @@ class PoolEventsTest(PoolTestBase):
c1.close()
eq_(canary, ['reset'])
+ def test_invalidate_event_no_exception(self):
+ p, canary = self._invalidate_event_fixture()
+
+ c1 = p.connect()
+ c1.close()
+ assert not canary.called
+ c1 = p.connect()
+ dbapi_con = c1.connection
+ c1.invalidate()
+ assert canary.call_args_list[0][0][0] is dbapi_con
+ assert canary.call_args_list[0][0][2] is None
+
+ def test_invalidate_event_exception(self):
+ p, canary = self._invalidate_event_fixture()
+
+ c1 = p.connect()
+ c1.close()
+ assert not canary.called
+ c1 = p.connect()
+ dbapi_con = c1.connection
+ exc = Exception("hi")
+ c1.invalidate(exc)
+ assert canary.call_args_list[0][0][0] is dbapi_con
+ assert canary.call_args_list[0][0][2] is exc
+
def test_checkin_event_gc(self):
p, canary = self._checkin_event_fixture()
@@ -827,7 +861,7 @@ class QueuePoolTest(PoolTestBase):
th.start()
threads.append(th)
for th in threads:
- th.join()
+ th.join(join_timeout)
assert len(timeouts) > 0
for t in timeouts:
@@ -864,22 +898,109 @@ class QueuePoolTest(PoolTestBase):
th.start()
threads.append(th)
for th in threads:
- th.join()
+ th.join(join_timeout)
self.assert_(max(peaks) <= max_overflow)
lazy_gc()
assert not pool._refs
+
+ def test_overflow_reset_on_failed_connect(self):
+ dbapi = Mock()
+
+ def failing_dbapi():
+ time.sleep(2)
+ raise Exception("connection failed")
+
+ creator = dbapi.connect
+ def create():
+ return creator()
+
+ p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ eq_(p._overflow, 1)
+ creator = failing_dbapi
+ assert_raises(Exception, p.connect)
+ eq_(p._overflow, 1)
+
+ @testing.requires.threading_with_mock
+ def test_hanging_connect_within_overflow(self):
+ """test that a single connect() call which is hanging
+ does not block other connections from proceeding."""
+
+ dbapi = Mock()
+ mutex = threading.Lock()
+
+ def hanging_dbapi():
+ time.sleep(2)
+ with mutex:
+ return dbapi.connect()
+
+ def fast_dbapi():
+ with mutex:
+ return dbapi.connect()
+
+ creator = threading.local()
+
+ def create():
+ return creator.mock_connector()
+
+ def run_test(name, pool, should_hang):
+ if should_hang:
+ creator.mock_connector = hanging_dbapi
+ else:
+ creator.mock_connector = fast_dbapi
+
+ conn = pool.connect()
+ conn.operation(name)
+ time.sleep(1)
+ conn.close()
+
+ p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
+
+ threads = [
+ threading.Thread(
+ target=run_test, args=("success_one", p, False)),
+ threading.Thread(
+ target=run_test, args=("success_two", p, False)),
+ threading.Thread(
+ target=run_test, args=("overflow_one", p, True)),
+ threading.Thread(
+ target=run_test, args=("overflow_two", p, False)),
+ threading.Thread(
+ target=run_test, args=("overflow_three", p, False))
+ ]
+ for t in threads:
+ t.start()
+ time.sleep(.2)
+
+ for t in threads:
+ t.join(timeout=join_timeout)
+ eq_(
+ dbapi.connect().operation.mock_calls,
+ [call("success_one"), call("success_two"),
+ call("overflow_two"), call("overflow_three"),
+ call("overflow_one")]
+ )
+
+
@testing.requires.threading_with_mock
def test_waiters_handled(self):
"""test that threads waiting for connections are
handled when the pool is replaced.
"""
+ mutex = threading.Lock()
dbapi = MockDBAPI()
def creator():
- return dbapi.connect()
+ mutex.acquire()
+ try:
+ return dbapi.connect()
+ finally:
+ mutex.release()
success = []
for timeout in (None, 30):
@@ -897,21 +1018,27 @@ class QueuePoolTest(PoolTestBase):
c1 = p.connect()
c2 = p.connect()
+ threads = []
for i in range(2):
t = threading.Thread(target=waiter,
args=(p, timeout, max_overflow))
- t.setDaemon(True) # so the tests dont hang if this fails
+ t.daemon = True
t.start()
+ threads.append(t)
- c1.invalidate()
- c2.invalidate()
- p2 = p._replace()
+ # this sleep makes sure that the
+ # two waiter threads hit upon wait()
+ # inside the queue, before we invalidate the other
+ # two conns
time.sleep(.2)
+ p2 = p._replace()
+
+ for t in threads:
+ t.join(join_timeout)
eq_(len(success), 12, "successes: %s" % success)
@testing.requires.threading_with_mock
- @testing.requires.python26
def test_notify_waiters(self):
dbapi = MockDBAPI()
canary = []
@@ -924,9 +1051,7 @@ class QueuePoolTest(PoolTestBase):
p1 = pool.QueuePool(creator=creator1,
pool_size=1, timeout=None,
max_overflow=0)
- p2 = pool.QueuePool(creator=creator2,
- pool_size=1, timeout=None,
- max_overflow=-1)
+ p2 = pool.NullPool(creator=creator2)
def waiter(p):
conn = p.connect()
time.sleep(.5)
@@ -934,14 +1059,18 @@ class QueuePoolTest(PoolTestBase):
c1 = p1.connect()
+ threads = []
for i in range(5):
t = threading.Thread(target=waiter, args=(p1, ))
- t.setDaemon(True)
t.start()
+ threads.append(t)
time.sleep(.5)
eq_(canary, [1])
p1._pool.abort(p2)
- time.sleep(1)
+
+ for t in threads:
+ t.join(join_timeout)
+
eq_(canary, [1, 2, 2, 2, 2, 2])
def test_dispose_closes_pooled(self):
@@ -987,6 +1116,7 @@ class QueuePoolTest(PoolTestBase):
self._test_overflow(40, 5)
def test_mixed_close(self):
+ pool._refs.clear()
p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
c1 = p.connect()
c2 = p.connect()
@@ -1198,6 +1328,96 @@ class QueuePoolTest(PoolTestBase):
c2 = p.connect()
assert c2.connection is not None
+class ResetOnReturnTest(PoolTestBase):
+ def _fixture(self, **kw):
+ dbapi = Mock()
+ return dbapi, pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)
+
+ def test_plain_rollback(self):
+ dbapi, p = self._fixture(reset_on_return='rollback')
+
+ c1 = p.connect()
+ c1.close()
+ assert dbapi.connect().rollback.called
+ assert not dbapi.connect().commit.called
+
+ def test_plain_commit(self):
+ dbapi, p = self._fixture(reset_on_return='commit')
+
+ c1 = p.connect()
+ c1.close()
+ assert not dbapi.connect().rollback.called
+ assert dbapi.connect().commit.called
+
+ def test_plain_none(self):
+ dbapi, p = self._fixture(reset_on_return=None)
+
+ c1 = p.connect()
+ c1.close()
+ assert not dbapi.connect().rollback.called
+ assert not dbapi.connect().commit.called
+
+ def test_agent_rollback(self):
+ dbapi, p = self._fixture(reset_on_return='rollback')
+
+ class Agent(object):
+ def __init__(self, conn):
+ self.conn = conn
+
+ def rollback(self):
+ self.conn.special_rollback()
+
+ def commit(self):
+ self.conn.special_commit()
+
+ c1 = p.connect()
+ c1._reset_agent = Agent(c1)
+ c1.close()
+
+ assert dbapi.connect().special_rollback.called
+ assert not dbapi.connect().special_commit.called
+
+ assert not dbapi.connect().rollback.called
+ assert not dbapi.connect().commit.called
+
+ c1 = p.connect()
+ c1.close()
+ eq_(dbapi.connect().special_rollback.call_count, 1)
+ eq_(dbapi.connect().special_commit.call_count, 0)
+
+ assert dbapi.connect().rollback.called
+ assert not dbapi.connect().commit.called
+
+ def test_agent_commit(self):
+ dbapi, p = self._fixture(reset_on_return='commit')
+
+ class Agent(object):
+ def __init__(self, conn):
+ self.conn = conn
+
+ def rollback(self):
+ self.conn.special_rollback()
+
+ def commit(self):
+ self.conn.special_commit()
+
+ c1 = p.connect()
+ c1._reset_agent = Agent(c1)
+ c1.close()
+ assert not dbapi.connect().special_rollback.called
+ assert dbapi.connect().special_commit.called
+
+ assert not dbapi.connect().rollback.called
+ assert not dbapi.connect().commit.called
+
+ c1 = p.connect()
+ c1.close()
+
+ eq_(dbapi.connect().special_rollback.call_count, 0)
+ eq_(dbapi.connect().special_commit.call_count, 1)
+ assert not dbapi.connect().rollback.called
+ assert dbapi.connect().commit.called
+
class SingletonThreadPoolTest(PoolTestBase):
@testing.requires.threading_with_mock
@@ -1245,7 +1465,7 @@ class SingletonThreadPoolTest(PoolTestBase):
th.start()
threads.append(th)
for th in threads:
- th.join()
+ th.join(join_timeout)
assert len(p._all_conns) == 3
if strong_refs: