diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-06-10 21:18:24 +0000 |
| commit | 45cec095b4904ba71425d2fe18c143982dd08f43 (patch) | |
| tree | af5e540fdcbf1cb2a3337157d69d4b40be010fa8 /test/engine/test_reconnect.py | |
| parent | 698a3c1ac665e7cd2ef8d5ad3ebf51b7fe6661f4 (diff) | |
| download | sqlalchemy-45cec095b4904ba71425d2fe18c143982dd08f43.tar.gz | |
- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run
the tests. [ticket:970]
Diffstat (limited to 'test/engine/test_reconnect.py')
| -rw-r--r-- | test/engine/test_reconnect.py | 354 |
1 files changed, 354 insertions, 0 deletions
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py new file mode 100644 index 000000000..3a525c2a7 --- /dev/null +++ b/test/engine/test_reconnect.py @@ -0,0 +1,354 @@ +from sqlalchemy.test.testing import eq_ +import weakref +from sqlalchemy import select, MetaData, Integer, String, pool +from sqlalchemy.test.schema import Table +from sqlalchemy.test.schema import Column +import sqlalchemy as tsa +from sqlalchemy.test import TestBase, testing, engines +import time +import gc + +class MockDisconnect(Exception): + pass + +class MockDBAPI(object): + def __init__(self): + self.paramstyle = 'named' + self.connections = weakref.WeakKeyDictionary() + def connect(self, *args, **kwargs): + return MockConnection(self) + def shutdown(self): + for c in self.connections: + c.explode[0] = True + Error = MockDisconnect + +class MockConnection(object): + def __init__(self, dbapi): + dbapi.connections[self] = True + self.explode = [False] + def rollback(self): + pass + def commit(self): + pass + def cursor(self): + return MockCursor(self) + def close(self): + pass + +class MockCursor(object): + def __init__(self, parent): + self.explode = parent.explode + self.description = () + def execute(self, *args, **kwargs): + if self.explode[0]: + raise MockDisconnect("Lost the DB connection") + else: + return + def close(self): + pass + +db, dbapi = None, None +class MockReconnectTest(TestBase): + def setup(self): + global db, dbapi + dbapi = MockDBAPI() + + # create engine using our current dburi + db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi) + + # monkeypatch disconnect checker + db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) + + def test_reconnect(self): + """test that an 'is_disconnect' condition will invalidate the connection, and additionally + dispose the previous connection pool and recreate.""" + + + pid = id(db.pool) + + # make a connection + conn = db.connect() + + # connection works + conn.execute(select([1])) + + # create a second connection within the pool, which we'll ensure also goes away + conn2 = db.connect() + conn2.close() + + # two connections opened total now + assert len(dbapi.connections) == 2 + + # set it to fail + dbapi.shutdown() + + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError: + pass + + # assert was invalidated + assert not conn.closed + assert conn.invalidated + + # close shouldnt break + conn.close() + + assert id(db.pool) != pid + + # ensure all connections closed (pool was recycled) + gc.collect() + assert len(dbapi.connections) == 0 + + conn =db.connect() + conn.execute(select([1])) + conn.close() + assert len(dbapi.connections) == 1 + + def test_invalidate_trans(self): + conn = db.connect() + trans = conn.begin() + dbapi.shutdown() + + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError: + pass + + # assert was invalidated + gc.collect() + assert len(dbapi.connections) == 0 + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute(select([1])) + assert False + except tsa.exc.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except tsa.exc.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + conn.execute(select([1])) + assert not conn.invalidated + + assert len(dbapi.connections) == 1 + + def test_conn_reusable(self): + conn = db.connect() + + conn.execute(select([1])) + + assert len(dbapi.connections) == 1 + + dbapi.shutdown() + + # raises error + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError: + pass + + assert not conn.closed + assert conn.invalidated + + # ensure all connections closed (pool was recycled) + gc.collect() + assert len(dbapi.connections) == 0 + + # test reconnects + conn.execute(select([1])) + assert not conn.invalidated + assert len(dbapi.connections) == 1 + +engine = None +class RealReconnectTest(TestBase): + def setup(self): + global engine + engine = engines.reconnecting_engine() + + def teardown(self): + engine.dispose() + + def test_reconnect(self): + conn = engine.connect() + + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + + assert conn.invalidated + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + + # one more time + engine.test_shutdown() + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + assert conn.invalidated + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + + conn.close() + + def test_null_pool(self): + engine = engines.reconnecting_engine(options=dict(poolclass=pool.NullPool)) + conn = engine.connect() + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + engine.test_shutdown() + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + assert not conn.closed + assert conn.invalidated + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + + def test_close(self): + conn = engine.connect() + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + + conn.close() + conn = engine.connect() + eq_(conn.execute(select([1])).scalar(), 1) + + def test_with_transaction(self): + conn = engine.connect() + + trans = conn.begin() + + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.closed + + engine.test_shutdown() + + try: + conn.execute(select([1])) + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert not conn.closed + assert conn.invalidated + assert trans.is_active + + try: + conn.execute(select([1])) + assert False + except tsa.exc.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + try: + trans.commit() + assert False + except tsa.exc.InvalidRequestError, e: + assert str(e) == "Can't reconnect until invalid transaction is rolled back" + + assert trans.is_active + + trans.rollback() + assert not trans.is_active + + assert conn.invalidated + eq_(conn.execute(select([1])).scalar(), 1) + assert not conn.invalidated + +class RecycleTest(TestBase): + def test_basic(self): + for threadlocal in (False, True): + engine = engines.reconnecting_engine(options={'pool_recycle':1, 'pool_threadlocal':threadlocal}) + + conn = engine.contextual_connect() + eq_(conn.execute(select([1])).scalar(), 1) + conn.close() + + engine.test_shutdown() + time.sleep(2) + + conn = engine.contextual_connect() + eq_(conn.execute(select([1])).scalar(), 1) + conn.close() + +meta, table, engine = None, None, None +class InvalidateDuringResultTest(TestBase): + def setup(self): + global meta, table, engine + engine = engines.reconnecting_engine() + meta = MetaData(engine) + table = Table('sometable', meta, + Column('id', Integer, primary_key=True), + Column('name', String(50))) + meta.create_all() + table.insert().execute( + [{'id':i, 'name':'row %d' % i} for i in range(1, 100)] + ) + + def teardown(self): + meta.drop_all() + engine.dispose() + + @testing.fails_on('mysql', 'FIXME: unknown') + def test_invalidate_on_results(self): + conn = engine.connect() + + result = conn.execute("select * from sometable") + for x in xrange(20): + result.fetchone() + + engine.test_shutdown() + try: + result.fetchone() + assert False + except tsa.exc.DBAPIError, e: + if not e.connection_invalidated: + raise + + assert conn.invalidated + |
