diff options
Diffstat (limited to 'oslo_db/tests/sqlalchemy/test_exc_filters.py')
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_exc_filters.py | 335 |
1 files changed, 232 insertions, 103 deletions
diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py index af3cd91..796ba6c 100644 --- a/oslo_db/tests/sqlalchemy/test_exc_filters.py +++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py @@ -1190,41 +1190,16 @@ class IntegrationTest(db_test_base._DbTestCase): self.assertIn("no such function", str(matched)) -class TestDBDisconnected(TestsExceptionFilter): - - @contextlib.contextmanager - def _fixture( - self, - dialect_name, exception, num_disconnects, is_disconnect=True): - engine = self.engine - - event.listen(engine, "engine_connect", engines._connect_ping_listener) - - real_do_execute = engine.dialect.do_execute - counter = itertools.count(1) - - def fake_do_execute(self, *arg, **kw): - if next(counter) > num_disconnects: - return real_do_execute(self, *arg, **kw) - else: - raise exception - - with self._dbapi_fixture(dialect_name): - with test_utils.nested( - mock.patch.object(engine.dialect, - "do_execute", - fake_do_execute), - mock.patch.object(engine.dialect, - "is_disconnect", - mock.Mock(return_value=is_disconnect)) - ): - yield +class TestDBDisconnectedFixture(TestsExceptionFilter): + native_pre_ping = False def _test_ping_listener_disconnected( self, dialect_name, exc_obj, is_disconnect=True, ): - with self._fixture(dialect_name, exc_obj, 1, is_disconnect): - conn = self.engine.connect() + with self._fixture( + dialect_name, exc_obj, False, is_disconnect, + ) as engine: + conn = engine.connect() with conn.begin(): self.assertEqual( 1, conn.execute(sqla.select(1)).scalars().first(), @@ -1233,19 +1208,145 @@ class TestDBDisconnected(TestsExceptionFilter): self.assertFalse(conn.invalidated) self.assertTrue(conn.in_transaction()) - with self._fixture(dialect_name, exc_obj, 2, is_disconnect): + with self._fixture( + dialect_name, exc_obj, True, is_disconnect, + ) as engine: self.assertRaises( exception.DBConnectionError, - self.engine.connect + engine.connect ) # test implicit execution - with self._fixture(dialect_name, exc_obj, 1): - with self.engine.connect() as conn: + with self._fixture(dialect_name, exc_obj, False) as engine: + with engine.connect() as conn: self.assertEqual( 1, conn.execute(sqla.select(1)).scalars().first(), ) + @contextlib.contextmanager + def _fixture( + self, + dialect_name, + exception, + db_stays_down, + is_disconnect=True, + ): + """Fixture for testing the ping listener. + + For SQLAlchemy 2.0, the mocking is placed more deeply in the + stack within the DBAPI connection / cursor so that we can also + effectively mock out the "pre ping" condition. + + :param dialect_name: dialect to use. "postgresql" or "mysql" + :param exception: an exception class to raise + :param db_stays_down: if True, the database will stay down after the + first ping fails + :param is_disconnect: whether or not the SQLAlchemy dialect should + consider the exception object as a "disconnect error". Openstack's + own exception handlers upgrade various DB exceptions to be + "disconnect" scenarios that SQLAlchemy itself does not, such as + some specific Galera error messages. + + The importance of an exception being a "disconnect error" means that + SQLAlchemy knows it can discard the connection and then reconnect. + If the error is not a "disconnection error", then it raises. + """ + connect_args = {} + patchers = [] + db_disconnected = False + + class DisconnectCursorMixin: + def execute(self, *arg, **kw): + if db_disconnected: + raise exception + else: + return super().execute(*arg, **kw) + + if dialect_name == "postgresql": + import psycopg2.extensions + + class Curs(DisconnectCursorMixin, psycopg2.extensions.cursor): + pass + + connect_args = {"cursor_factory": Curs} + + elif dialect_name == "mysql": + import pymysql + + def fake_ping(self, *arg, **kw): + if db_disconnected: + raise exception + else: + return True + + class Curs(DisconnectCursorMixin, pymysql.cursors.Cursor): + pass + + connect_args = {"cursorclass": Curs} + + patchers.append( + mock.patch.object( + pymysql.Connection, "ping", fake_ping + ) + ) + else: + raise NotImplementedError() + + with mock.patch.object( + compat, + "native_pre_ping_event_support", + self.native_pre_ping, + ): + engine = engines.create_engine( + self.engine.url, max_retries=0) + + # 1. override how we connect. if we want the DB to be down + # for the moment, but recover, reset db_disconnected after + # connect is called. If we want the DB to stay down, then + # make sure connect raises the error also. + @event.listens_for(engine, "do_connect") + def _connect(dialect, connrec, cargs, cparams): + nonlocal db_disconnected + + # while we're here, add our cursor classes to the DBAPI + # connect args + cparams.update(connect_args) + + if db_disconnected: + if db_stays_down: + raise exception + else: + db_disconnected = False + + # 2. initialize the dialect with a first connect + conn = engine.connect() + conn.close() + + # 3. add additional patchers + patchers.extend([ + mock.patch.object( + engine.dialect.dbapi, + "Error", + self.Error, + ), + mock.patch.object( + engine.dialect, + "is_disconnect", + mock.Mock(return_value=is_disconnect), + ), + ]) + + with test_utils.nested(*patchers): + # "disconnect" the DB + db_disconnected = True + yield engine + + +class MySQLPrePingHandlerTests( + db_test_base._MySQLOpportunisticTestCase, + TestDBDisconnectedFixture, +): + def test_mariadb_error_1927(self): for code in [1927]: self._test_ping_listener_disconnected( @@ -1298,6 +1399,26 @@ class TestDBDisconnected(TestsExceptionFilter): is_disconnect=False ) + def test_mysql_w_disconnect_flag(self): + for code in [2002, 2003, 2002]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code) + ) + + def test_mysql_wo_disconnect_flag(self): + for code in [2002, 2003]: + self._test_ping_listener_disconnected( + "mysql", + self.OperationalError('%d MySQL server has gone away' % code), + is_disconnect=False + ) + + +class PostgreSQLPrePingHandlerTests( + db_test_base._PostgreSQLOpportunisticTestCase, + TestDBDisconnectedFixture): + def test_postgresql_ping_listener_disconnected(self): self._test_ping_listener_disconnected( "postgresql", @@ -1314,79 +1435,18 @@ class TestDBDisconnected(TestsExceptionFilter): ) -class TestDBConnectRetry(TestsExceptionFilter): - - def _run_test(self, dialect_name, exception, count, retries): - counter = itertools.count() - - engine = self.engine - - # empty out the connection pool - engine.dispose() - - connect_fn = engine.dialect.connect - - def cant_connect(*arg, **kw): - if next(counter) < count: - raise exception - else: - return connect_fn(*arg, **kw) - - with self._dbapi_fixture(dialect_name): - with mock.patch.object(engine.dialect, "connect", cant_connect): - return engines._test_connection(engine, retries, .01) - - def test_connect_no_retries(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, 0 - ) - # didnt connect because nothing was tried - self.assertIsNone(conn) - - def test_connect_inifinite_retries(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, -1 - ) - # conn is good - self.assertEqual(1, conn.scalar(sqla.select(1))) - - def test_connect_retry_past_failure(self): - conn = self._run_test( - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 2, 3 - ) - # conn is good - self.assertEqual(1, conn.scalar(sqla.select(1))) - - def test_connect_retry_not_candidate_exception(self): - self.assertRaises( - sqla.exc.OperationalError, # remember, we pass OperationalErrors - # through at the moment :) - self._run_test, - "mysql", - self.OperationalError("Error: (2015) I can't connect period"), - 2, 3 - ) +if compat.sqla_2: + class MySQLNativePrePingTests(MySQLPrePingHandlerTests): + native_pre_ping = True - def test_connect_retry_stops_infailure(self): - self.assertRaises( - exception.DBConnectionError, - self._run_test, - "mysql", - self.OperationalError("Error: (2003) something wrong"), - 3, 2 - ) + class PostgreSQLNativePrePingTests(PostgreSQLPrePingHandlerTests): + native_pre_ping = True -class TestDBConnectPingWrapping(TestsExceptionFilter): +class TestDBConnectPingListener(TestsExceptionFilter): def setUp(self): - super(TestDBConnectPingWrapping, self).setUp() + super().setUp() event.listen( self.engine, "engine_connect", engines._connect_ping_listener) @@ -1475,6 +1535,75 @@ class TestDBConnectPingWrapping(TestsExceptionFilter): ) +class TestDBConnectRetry(TestsExceptionFilter): + + def _run_test(self, dialect_name, exception, count, retries): + counter = itertools.count() + + engine = self.engine + + # empty out the connection pool + engine.dispose() + + connect_fn = engine.dialect.connect + + def cant_connect(*arg, **kw): + if next(counter) < count: + raise exception + else: + return connect_fn(*arg, **kw) + + with self._dbapi_fixture(dialect_name): + with mock.patch.object(engine.dialect, "connect", cant_connect): + return engines._test_connection(engine, retries, .01) + + def test_connect_no_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 0 + ) + # didnt connect because nothing was tried + self.assertIsNone(conn) + + def test_connect_inifinite_retries(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, -1 + ) + # conn is good + self.assertEqual(1, conn.scalar(sqla.select(1))) + + def test_connect_retry_past_failure(self): + conn = self._run_test( + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 2, 3 + ) + # conn is good + self.assertEqual(1, conn.scalar(sqla.select(1))) + + def test_connect_retry_not_candidate_exception(self): + self.assertRaises( + sqla.exc.OperationalError, # remember, we pass OperationalErrors + # through at the moment :) + self._run_test, + "mysql", + self.OperationalError("Error: (2015) I can't connect period"), + 2, 3 + ) + + def test_connect_retry_stops_infailure(self): + self.assertRaises( + exception.DBConnectionError, + self._run_test, + "mysql", + self.OperationalError("Error: (2003) something wrong"), + 3, 2 + ) + + class TestsErrorHandler(TestsExceptionFilter): def test_multiple_error_handlers(self): handler = mock.MagicMock(return_value=None) |