summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oslo_db/sqlalchemy/compat/__init__.py2
-rw-r--r--oslo_db/sqlalchemy/engines.py14
-rw-r--r--oslo_db/sqlalchemy/exc_filters.py27
-rw-r--r--oslo_db/tests/sqlalchemy/test_exc_filters.py335
4 files changed, 267 insertions, 111 deletions
diff --git a/oslo_db/sqlalchemy/compat/__init__.py b/oslo_db/sqlalchemy/compat/__init__.py
index 6713696..d209207 100644
--- a/oslo_db/sqlalchemy/compat/__init__.py
+++ b/oslo_db/sqlalchemy/compat/__init__.py
@@ -18,6 +18,8 @@ from sqlalchemy import __version__
_vers = versionutils.convert_version_to_tuple(__version__)
sqla_2 = _vers >= (2, )
+native_pre_ping_event_support = _vers >= (2, 0, 5)
+
def dialect_from_exception_context(ctx):
if sqla_2:
diff --git a/oslo_db/sqlalchemy/engines.py b/oslo_db/sqlalchemy/engines.py
index 146d189..7c36c8a 100644
--- a/oslo_db/sqlalchemy/engines.py
+++ b/oslo_db/sqlalchemy/engines.py
@@ -60,6 +60,12 @@ def _connect_ping_listener(connection, branch):
Ping the server at transaction begin and transparently reconnect
if a disconnect exception occurs.
+ This listener is used up until SQLAlchemy 2.0.5. At 2.0.5, we use the
+ ``pool_pre_ping`` parameter instead of this event handler.
+
+ Note the current test suite in test_exc_filters still **tests** this
+ handler using all SQLAlchemy versions including 2.0.5 and greater.
+
"""
if branch:
return
@@ -199,8 +205,11 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
_vet_url(url)
+ _native_pre_ping = compat.native_pre_ping_event_support
+
engine_args = {
'pool_recycle': connection_recycle_time,
+ 'pool_pre_ping': _native_pre_ping,
'connect_args': {},
'logging_name': logging_name
}
@@ -236,9 +245,10 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
# register alternate exception handler
exc_filters.register_engine(engine)
- # register engine connect handler
+ if not _native_pre_ping:
+ # register engine connect handler.
- event.listen(engine, "engine_connect", _connect_ping_listener)
+ event.listen(engine, "engine_connect", _connect_ping_listener)
# initial connect + test
# NOTE(viktors): the current implementation of _test_connection()
diff --git a/oslo_db/sqlalchemy/exc_filters.py b/oslo_db/sqlalchemy/exc_filters.py
index e578987..420b5c7 100644
--- a/oslo_db/sqlalchemy/exc_filters.py
+++ b/oslo_db/sqlalchemy/exc_filters.py
@@ -20,7 +20,7 @@ from sqlalchemy import event
from sqlalchemy import exc as sqla_exc
from oslo_db import exception
-
+from oslo_db.sqlalchemy import compat
LOG = logging.getLogger(__name__)
@@ -377,6 +377,7 @@ def _raise_operational_errors_directly_filter(operational_error,
def _is_db_connection_error(operational_error, match, engine_name,
is_disconnect):
"""Detect the exception as indicating a recoverable error on connect."""
+
raise exception.DBConnectionError(operational_error)
@@ -423,13 +424,14 @@ def handler(context):
more specific exception class are attempted first.
"""
- def _dialect_registries(engine):
- if engine.dialect.name in _registry:
- yield _registry[engine.dialect.name]
+ def _dialect_registries(dialect):
+ if dialect.name in _registry:
+ yield _registry[dialect.name]
if '*' in _registry:
yield _registry['*']
- for per_dialect in _dialect_registries(context.engine):
+ dialect = compat.dialect_from_exception_context(context)
+ for per_dialect in _dialect_registries(dialect):
for exc in (
context.sqlalchemy_exception,
context.original_exception):
@@ -443,7 +445,7 @@ def handler(context):
fn(
exc,
match,
- context.engine.dialect.name,
+ dialect.name,
context.is_disconnect)
except exception.DBError as dbe:
if (
@@ -460,6 +462,19 @@ def handler(context):
if isinstance(
dbe, exception.DBConnectionError):
context.is_disconnect = True
+
+ # new in 2.0.5
+ if (
+ hasattr(context, "is_pre_ping") and
+ context.is_pre_ping
+ ):
+ # if this is a pre-ping, need to
+ # integrate with the built
+ # in pre-ping handler that doesnt know
+ # about DBConnectionError, just needs
+ # the updated status
+ return None
+
return dbe
diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py
index af3cd91..796ba6c 100644
--- a/oslo_db/tests/sqlalchemy/test_exc_filters.py
+++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py
@@ -1190,41 +1190,16 @@ class IntegrationTest(db_test_base._DbTestCase):
self.assertIn("no such function", str(matched))
-class TestDBDisconnected(TestsExceptionFilter):
-
- @contextlib.contextmanager
- def _fixture(
- self,
- dialect_name, exception, num_disconnects, is_disconnect=True):
- engine = self.engine
-
- event.listen(engine, "engine_connect", engines._connect_ping_listener)
-
- real_do_execute = engine.dialect.do_execute
- counter = itertools.count(1)
-
- def fake_do_execute(self, *arg, **kw):
- if next(counter) > num_disconnects:
- return real_do_execute(self, *arg, **kw)
- else:
- raise exception
-
- with self._dbapi_fixture(dialect_name):
- with test_utils.nested(
- mock.patch.object(engine.dialect,
- "do_execute",
- fake_do_execute),
- mock.patch.object(engine.dialect,
- "is_disconnect",
- mock.Mock(return_value=is_disconnect))
- ):
- yield
+class TestDBDisconnectedFixture(TestsExceptionFilter):
+ native_pre_ping = False
def _test_ping_listener_disconnected(
self, dialect_name, exc_obj, is_disconnect=True,
):
- with self._fixture(dialect_name, exc_obj, 1, is_disconnect):
- conn = self.engine.connect()
+ with self._fixture(
+ dialect_name, exc_obj, False, is_disconnect,
+ ) as engine:
+ conn = engine.connect()
with conn.begin():
self.assertEqual(
1, conn.execute(sqla.select(1)).scalars().first(),
@@ -1233,19 +1208,145 @@ class TestDBDisconnected(TestsExceptionFilter):
self.assertFalse(conn.invalidated)
self.assertTrue(conn.in_transaction())
- with self._fixture(dialect_name, exc_obj, 2, is_disconnect):
+ with self._fixture(
+ dialect_name, exc_obj, True, is_disconnect,
+ ) as engine:
self.assertRaises(
exception.DBConnectionError,
- self.engine.connect
+ engine.connect
)
# test implicit execution
- with self._fixture(dialect_name, exc_obj, 1):
- with self.engine.connect() as conn:
+ with self._fixture(dialect_name, exc_obj, False) as engine:
+ with engine.connect() as conn:
self.assertEqual(
1, conn.execute(sqla.select(1)).scalars().first(),
)
+ @contextlib.contextmanager
+ def _fixture(
+ self,
+ dialect_name,
+ exception,
+ db_stays_down,
+ is_disconnect=True,
+ ):
+ """Fixture for testing the ping listener.
+
+ For SQLAlchemy 2.0, the mocking is placed more deeply in the
+ stack within the DBAPI connection / cursor so that we can also
+ effectively mock out the "pre ping" condition.
+
+ :param dialect_name: dialect to use. "postgresql" or "mysql"
+ :param exception: an exception class to raise
+ :param db_stays_down: if True, the database will stay down after the
+ first ping fails
+ :param is_disconnect: whether or not the SQLAlchemy dialect should
+ consider the exception object as a "disconnect error". Openstack's
+ own exception handlers upgrade various DB exceptions to be
+ "disconnect" scenarios that SQLAlchemy itself does not, such as
+ some specific Galera error messages.
+
+ The importance of an exception being a "disconnect error" means that
+ SQLAlchemy knows it can discard the connection and then reconnect.
+ If the error is not a "disconnection error", then it raises.
+ """
+ connect_args = {}
+ patchers = []
+ db_disconnected = False
+
+ class DisconnectCursorMixin:
+ def execute(self, *arg, **kw):
+ if db_disconnected:
+ raise exception
+ else:
+ return super().execute(*arg, **kw)
+
+ if dialect_name == "postgresql":
+ import psycopg2.extensions
+
+ class Curs(DisconnectCursorMixin, psycopg2.extensions.cursor):
+ pass
+
+ connect_args = {"cursor_factory": Curs}
+
+ elif dialect_name == "mysql":
+ import pymysql
+
+ def fake_ping(self, *arg, **kw):
+ if db_disconnected:
+ raise exception
+ else:
+ return True
+
+ class Curs(DisconnectCursorMixin, pymysql.cursors.Cursor):
+ pass
+
+ connect_args = {"cursorclass": Curs}
+
+ patchers.append(
+ mock.patch.object(
+ pymysql.Connection, "ping", fake_ping
+ )
+ )
+ else:
+ raise NotImplementedError()
+
+ with mock.patch.object(
+ compat,
+ "native_pre_ping_event_support",
+ self.native_pre_ping,
+ ):
+ engine = engines.create_engine(
+ self.engine.url, max_retries=0)
+
+ # 1. override how we connect. if we want the DB to be down
+ # for the moment, but recover, reset db_disconnected after
+ # connect is called. If we want the DB to stay down, then
+ # make sure connect raises the error also.
+ @event.listens_for(engine, "do_connect")
+ def _connect(dialect, connrec, cargs, cparams):
+ nonlocal db_disconnected
+
+ # while we're here, add our cursor classes to the DBAPI
+ # connect args
+ cparams.update(connect_args)
+
+ if db_disconnected:
+ if db_stays_down:
+ raise exception
+ else:
+ db_disconnected = False
+
+ # 2. initialize the dialect with a first connect
+ conn = engine.connect()
+ conn.close()
+
+ # 3. add additional patchers
+ patchers.extend([
+ mock.patch.object(
+ engine.dialect.dbapi,
+ "Error",
+ self.Error,
+ ),
+ mock.patch.object(
+ engine.dialect,
+ "is_disconnect",
+ mock.Mock(return_value=is_disconnect),
+ ),
+ ])
+
+ with test_utils.nested(*patchers):
+ # "disconnect" the DB
+ db_disconnected = True
+ yield engine
+
+
+class MySQLPrePingHandlerTests(
+ db_test_base._MySQLOpportunisticTestCase,
+ TestDBDisconnectedFixture,
+):
+
def test_mariadb_error_1927(self):
for code in [1927]:
self._test_ping_listener_disconnected(
@@ -1298,6 +1399,26 @@ class TestDBDisconnected(TestsExceptionFilter):
is_disconnect=False
)
+ def test_mysql_w_disconnect_flag(self):
+ for code in [2002, 2003, 2002]:
+ self._test_ping_listener_disconnected(
+ "mysql",
+ self.OperationalError('%d MySQL server has gone away' % code)
+ )
+
+ def test_mysql_wo_disconnect_flag(self):
+ for code in [2002, 2003]:
+ self._test_ping_listener_disconnected(
+ "mysql",
+ self.OperationalError('%d MySQL server has gone away' % code),
+ is_disconnect=False
+ )
+
+
+class PostgreSQLPrePingHandlerTests(
+ db_test_base._PostgreSQLOpportunisticTestCase,
+ TestDBDisconnectedFixture):
+
def test_postgresql_ping_listener_disconnected(self):
self._test_ping_listener_disconnected(
"postgresql",
@@ -1314,79 +1435,18 @@ class TestDBDisconnected(TestsExceptionFilter):
)
-class TestDBConnectRetry(TestsExceptionFilter):
-
- def _run_test(self, dialect_name, exception, count, retries):
- counter = itertools.count()
-
- engine = self.engine
-
- # empty out the connection pool
- engine.dispose()
-
- connect_fn = engine.dialect.connect
-
- def cant_connect(*arg, **kw):
- if next(counter) < count:
- raise exception
- else:
- return connect_fn(*arg, **kw)
-
- with self._dbapi_fixture(dialect_name):
- with mock.patch.object(engine.dialect, "connect", cant_connect):
- return engines._test_connection(engine, retries, .01)
-
- def test_connect_no_retries(self):
- conn = self._run_test(
- "mysql",
- self.OperationalError("Error: (2003) something wrong"),
- 2, 0
- )
- # didnt connect because nothing was tried
- self.assertIsNone(conn)
-
- def test_connect_inifinite_retries(self):
- conn = self._run_test(
- "mysql",
- self.OperationalError("Error: (2003) something wrong"),
- 2, -1
- )
- # conn is good
- self.assertEqual(1, conn.scalar(sqla.select(1)))
-
- def test_connect_retry_past_failure(self):
- conn = self._run_test(
- "mysql",
- self.OperationalError("Error: (2003) something wrong"),
- 2, 3
- )
- # conn is good
- self.assertEqual(1, conn.scalar(sqla.select(1)))
-
- def test_connect_retry_not_candidate_exception(self):
- self.assertRaises(
- sqla.exc.OperationalError, # remember, we pass OperationalErrors
- # through at the moment :)
- self._run_test,
- "mysql",
- self.OperationalError("Error: (2015) I can't connect period"),
- 2, 3
- )
+if compat.sqla_2:
+ class MySQLNativePrePingTests(MySQLPrePingHandlerTests):
+ native_pre_ping = True
- def test_connect_retry_stops_infailure(self):
- self.assertRaises(
- exception.DBConnectionError,
- self._run_test,
- "mysql",
- self.OperationalError("Error: (2003) something wrong"),
- 3, 2
- )
+ class PostgreSQLNativePrePingTests(PostgreSQLPrePingHandlerTests):
+ native_pre_ping = True
-class TestDBConnectPingWrapping(TestsExceptionFilter):
+class TestDBConnectPingListener(TestsExceptionFilter):
def setUp(self):
- super(TestDBConnectPingWrapping, self).setUp()
+ super().setUp()
event.listen(
self.engine, "engine_connect", engines._connect_ping_listener)
@@ -1475,6 +1535,75 @@ class TestDBConnectPingWrapping(TestsExceptionFilter):
)
+class TestDBConnectRetry(TestsExceptionFilter):
+
+ def _run_test(self, dialect_name, exception, count, retries):
+ counter = itertools.count()
+
+ engine = self.engine
+
+ # empty out the connection pool
+ engine.dispose()
+
+ connect_fn = engine.dialect.connect
+
+ def cant_connect(*arg, **kw):
+ if next(counter) < count:
+ raise exception
+ else:
+ return connect_fn(*arg, **kw)
+
+ with self._dbapi_fixture(dialect_name):
+ with mock.patch.object(engine.dialect, "connect", cant_connect):
+ return engines._test_connection(engine, retries, .01)
+
+ def test_connect_no_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 0
+ )
+ # didnt connect because nothing was tried
+ self.assertIsNone(conn)
+
+ def test_connect_inifinite_retries(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, -1
+ )
+ # conn is good
+ self.assertEqual(1, conn.scalar(sqla.select(1)))
+
+ def test_connect_retry_past_failure(self):
+ conn = self._run_test(
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 2, 3
+ )
+ # conn is good
+ self.assertEqual(1, conn.scalar(sqla.select(1)))
+
+ def test_connect_retry_not_candidate_exception(self):
+ self.assertRaises(
+ sqla.exc.OperationalError, # remember, we pass OperationalErrors
+ # through at the moment :)
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2015) I can't connect period"),
+ 2, 3
+ )
+
+ def test_connect_retry_stops_infailure(self):
+ self.assertRaises(
+ exception.DBConnectionError,
+ self._run_test,
+ "mysql",
+ self.OperationalError("Error: (2003) something wrong"),
+ 3, 2
+ )
+
+
class TestsErrorHandler(TestsExceptionFilter):
def test_multiple_error_handlers(self):
handler = mock.MagicMock(return_value=None)