diff options
Diffstat (limited to 'test/engine/test_pool.py')
-rw-r--r-- | test/engine/test_pool.py | 248 |
1 files changed, 234 insertions, 14 deletions
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 05c0487f8..2e4c2dc48 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -10,6 +10,8 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing.mock import Mock, call +join_timeout = 10 + def MockDBAPI(): def cursor(): while True: @@ -306,6 +308,13 @@ class PoolEventsTest(PoolTestBase): return p, canary + def _invalidate_event_fixture(self): + p = self._queuepool_fixture() + canary = Mock() + event.listen(p, 'invalidate', canary) + + return p, canary + def test_first_connect_event(self): p, canary = self._first_connect_event_fixture() @@ -409,6 +418,31 @@ class PoolEventsTest(PoolTestBase): c1.close() eq_(canary, ['reset']) + def test_invalidate_event_no_exception(self): + p, canary = self._invalidate_event_fixture() + + c1 = p.connect() + c1.close() + assert not canary.called + c1 = p.connect() + dbapi_con = c1.connection + c1.invalidate() + assert canary.call_args_list[0][0][0] is dbapi_con + assert canary.call_args_list[0][0][2] is None + + def test_invalidate_event_exception(self): + p, canary = self._invalidate_event_fixture() + + c1 = p.connect() + c1.close() + assert not canary.called + c1 = p.connect() + dbapi_con = c1.connection + exc = Exception("hi") + c1.invalidate(exc) + assert canary.call_args_list[0][0][0] is dbapi_con + assert canary.call_args_list[0][0][2] is exc + def test_checkin_event_gc(self): p, canary = self._checkin_event_fixture() @@ -827,7 +861,7 @@ class QueuePoolTest(PoolTestBase): th.start() threads.append(th) for th in threads: - th.join() + th.join(join_timeout) assert len(timeouts) > 0 for t in timeouts: @@ -864,22 +898,109 @@ class QueuePoolTest(PoolTestBase): th.start() threads.append(th) for th in threads: - th.join() + th.join(join_timeout) self.assert_(max(peaks) <= max_overflow) lazy_gc() assert not pool._refs + + def test_overflow_reset_on_failed_connect(self): + dbapi = Mock() + + def failing_dbapi(): + time.sleep(2) + raise Exception("connection failed") + + creator = dbapi.connect + def create(): + return creator() + + p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3) + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + eq_(p._overflow, 1) + creator = failing_dbapi + assert_raises(Exception, p.connect) + eq_(p._overflow, 1) + + @testing.requires.threading_with_mock + def test_hanging_connect_within_overflow(self): + """test that a single connect() call which is hanging + does not block other connections from proceeding.""" + + dbapi = Mock() + mutex = threading.Lock() + + def hanging_dbapi(): + time.sleep(2) + with mutex: + return dbapi.connect() + + def fast_dbapi(): + with mutex: + return dbapi.connect() + + creator = threading.local() + + def create(): + return creator.mock_connector() + + def run_test(name, pool, should_hang): + if should_hang: + creator.mock_connector = hanging_dbapi + else: + creator.mock_connector = fast_dbapi + + conn = pool.connect() + conn.operation(name) + time.sleep(1) + conn.close() + + p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3) + + threads = [ + threading.Thread( + target=run_test, args=("success_one", p, False)), + threading.Thread( + target=run_test, args=("success_two", p, False)), + threading.Thread( + target=run_test, args=("overflow_one", p, True)), + threading.Thread( + target=run_test, args=("overflow_two", p, False)), + threading.Thread( + target=run_test, args=("overflow_three", p, False)) + ] + for t in threads: + t.start() + time.sleep(.2) + + for t in threads: + t.join(timeout=join_timeout) + eq_( + dbapi.connect().operation.mock_calls, + [call("success_one"), call("success_two"), + call("overflow_two"), call("overflow_three"), + call("overflow_one")] + ) + + @testing.requires.threading_with_mock def test_waiters_handled(self): """test that threads waiting for connections are handled when the pool is replaced. """ + mutex = threading.Lock() dbapi = MockDBAPI() def creator(): - return dbapi.connect() + mutex.acquire() + try: + return dbapi.connect() + finally: + mutex.release() success = [] for timeout in (None, 30): @@ -897,21 +1018,27 @@ class QueuePoolTest(PoolTestBase): c1 = p.connect() c2 = p.connect() + threads = [] for i in range(2): t = threading.Thread(target=waiter, args=(p, timeout, max_overflow)) - t.setDaemon(True) # so the tests dont hang if this fails + t.daemon = True t.start() + threads.append(t) - c1.invalidate() - c2.invalidate() - p2 = p._replace() + # this sleep makes sure that the + # two waiter threads hit upon wait() + # inside the queue, before we invalidate the other + # two conns time.sleep(.2) + p2 = p._replace() + + for t in threads: + t.join(join_timeout) eq_(len(success), 12, "successes: %s" % success) @testing.requires.threading_with_mock - @testing.requires.python26 def test_notify_waiters(self): dbapi = MockDBAPI() canary = [] @@ -924,9 +1051,7 @@ class QueuePoolTest(PoolTestBase): p1 = pool.QueuePool(creator=creator1, pool_size=1, timeout=None, max_overflow=0) - p2 = pool.QueuePool(creator=creator2, - pool_size=1, timeout=None, - max_overflow=-1) + p2 = pool.NullPool(creator=creator2) def waiter(p): conn = p.connect() time.sleep(.5) @@ -934,14 +1059,18 @@ class QueuePoolTest(PoolTestBase): c1 = p1.connect() + threads = [] for i in range(5): t = threading.Thread(target=waiter, args=(p1, )) - t.setDaemon(True) t.start() + threads.append(t) time.sleep(.5) eq_(canary, [1]) p1._pool.abort(p2) - time.sleep(1) + + for t in threads: + t.join(join_timeout) + eq_(canary, [1, 2, 2, 2, 2, 2]) def test_dispose_closes_pooled(self): @@ -987,6 +1116,7 @@ class QueuePoolTest(PoolTestBase): self._test_overflow(40, 5) def test_mixed_close(self): + pool._refs.clear() p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True) c1 = p.connect() c2 = p.connect() @@ -1198,6 +1328,96 @@ class QueuePoolTest(PoolTestBase): c2 = p.connect() assert c2.connection is not None +class ResetOnReturnTest(PoolTestBase): + def _fixture(self, **kw): + dbapi = Mock() + return dbapi, pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw) + + def test_plain_rollback(self): + dbapi, p = self._fixture(reset_on_return='rollback') + + c1 = p.connect() + c1.close() + assert dbapi.connect().rollback.called + assert not dbapi.connect().commit.called + + def test_plain_commit(self): + dbapi, p = self._fixture(reset_on_return='commit') + + c1 = p.connect() + c1.close() + assert not dbapi.connect().rollback.called + assert dbapi.connect().commit.called + + def test_plain_none(self): + dbapi, p = self._fixture(reset_on_return=None) + + c1 = p.connect() + c1.close() + assert not dbapi.connect().rollback.called + assert not dbapi.connect().commit.called + + def test_agent_rollback(self): + dbapi, p = self._fixture(reset_on_return='rollback') + + class Agent(object): + def __init__(self, conn): + self.conn = conn + + def rollback(self): + self.conn.special_rollback() + + def commit(self): + self.conn.special_commit() + + c1 = p.connect() + c1._reset_agent = Agent(c1) + c1.close() + + assert dbapi.connect().special_rollback.called + assert not dbapi.connect().special_commit.called + + assert not dbapi.connect().rollback.called + assert not dbapi.connect().commit.called + + c1 = p.connect() + c1.close() + eq_(dbapi.connect().special_rollback.call_count, 1) + eq_(dbapi.connect().special_commit.call_count, 0) + + assert dbapi.connect().rollback.called + assert not dbapi.connect().commit.called + + def test_agent_commit(self): + dbapi, p = self._fixture(reset_on_return='commit') + + class Agent(object): + def __init__(self, conn): + self.conn = conn + + def rollback(self): + self.conn.special_rollback() + + def commit(self): + self.conn.special_commit() + + c1 = p.connect() + c1._reset_agent = Agent(c1) + c1.close() + assert not dbapi.connect().special_rollback.called + assert dbapi.connect().special_commit.called + + assert not dbapi.connect().rollback.called + assert not dbapi.connect().commit.called + + c1 = p.connect() + c1.close() + + eq_(dbapi.connect().special_rollback.call_count, 0) + eq_(dbapi.connect().special_commit.call_count, 1) + assert not dbapi.connect().rollback.called + assert dbapi.connect().commit.called + class SingletonThreadPoolTest(PoolTestBase): @testing.requires.threading_with_mock @@ -1245,7 +1465,7 @@ class SingletonThreadPoolTest(PoolTestBase): th.start() threads.append(th) for th in threads: - th.join() + th.join(join_timeout) assert len(p._all_conns) == 3 if strong_refs: |