diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-07 11:17:47 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-07 11:18:25 -0500 |
commit | a8102ba496c4c11eae6b904a962cf352902f0de7 (patch) | |
tree | f181217c8f5c6a9ef7e354c8cb86561ae880d949 | |
parent | de95dc4ce5ed44cc63d9fd8b2e00a78858a73d2a (diff) | |
download | sqlalchemy-a8102ba496c4c11eae6b904a962cf352902f0de7.tar.gz |
test sqlite w/ savepoint workaround in session fixture test
Fixes: #7795
Change-Id: Ib790581555656c088f86c00080c70d19ca295a03
(cherry picked from commit fbacb1991585202a5bf22acb0d36b5c979bcfad8)
-rw-r--r-- | lib/sqlalchemy/testing/engines.py | 14 | ||||
-rw-r--r-- | test/orm/test_transaction.py | 12 | ||||
-rw-r--r-- | test/requirements.py | 10 |
3 files changed, 30 insertions, 6 deletions
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a92d476ac..b8be6b9bd 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -276,10 +276,12 @@ def testing_engine( future=None, asyncio=False, transfer_staticpool=False, + _sqlite_savepoint=False, ): """Produce an engine configured by --options with optional overrides.""" if asyncio: + assert not _sqlite_savepoint from sqlalchemy.ext.asyncio import ( create_async_engine as create_engine, ) @@ -294,9 +296,11 @@ def testing_engine( if not options: use_reaper = True scope = "function" + sqlite_savepoint = False else: use_reaper = options.pop("use_reaper", True) scope = options.pop("scope", "function") + sqlite_savepoint = options.pop("sqlite_savepoint", False) url = url or config.db.url @@ -312,6 +316,16 @@ def testing_engine( engine = create_engine(url, **options) + if sqlite_savepoint and engine.name == "sqlite": + # apply SQLite savepoint workaround + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN") + if transfer_staticpool: from sqlalchemy.pool import StaticPool diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 603ec079a..e077220e1 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -2526,10 +2526,10 @@ class NaturalPKRollbackTest(fixtures.MappedTest): class JoinIntoAnExternalTransactionFixture(object): """Test the "join into an external transaction" examples""" - __leave_connections_for_teardown__ = True - def setup_test(self): - self.engine = testing.db + self.engine = engines.testing_engine( + options={"use_reaper": False, "sqlite_savepoint": True} + ) self.connection = self.engine.connect() self.metadata = MetaData() @@ -2590,7 +2590,7 @@ class NewStyleJoinIntoAnExternalTransactionTest( # bind an individual Session to the connection self.session = Session(bind=self.connection, future=True) - if testing.requires.savepoints.enabled: + if testing.requires.compat_savepoints.enabled: self.nested = self.connection.begin_nested() @event.listens_for(self.session, "after_transaction_end") @@ -2607,7 +2607,7 @@ class NewStyleJoinIntoAnExternalTransactionTest( if self.trans.is_active: self.trans.rollback() - @testing.requires.savepoints + @testing.requires.compat_savepoints def test_something_with_context_managers(self): A = self.A @@ -2673,7 +2673,7 @@ class LegacyJoinIntoAnExternalTransactionTest( # bind an individual Session to the connection self.session = Session(bind=self.connection) - if testing.requires.savepoints.enabled: + if testing.requires.compat_savepoints.enabled: # start the session in a SAVEPOINT... self.session.begin_nested() diff --git a/test/requirements.py b/test/requirements.py index 1780e3b21..4c9ac40c5 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -559,6 +559,16 @@ class DefaultRequirements(SuiteRequirements): ) @property + def compat_savepoints(self): + """Target database must support savepoints, or a compat + recipe e.g. for sqlite will be used""" + + return skip_if( + ["sybase", ("mysql", "<", (5, 0, 3))], + "savepoints not supported", + ) + + @property def savepoints_w_release(self): return self.savepoints + skip_if( ["oracle", "mssql"], |