From 85568fc596c301563270efe217715f14aea8aa19 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 7 Jun 2019 11:19:23 -0400 Subject: Don't discard inactive transaction until it is explicitly rolled back The :class:`.Connection` object will now not clear a rolled-back transaction until the outermost transaction is explicitly rolled back. This is essentially the same behavior that the ORM :class:`.Session` has had for a long time, where an explicit call to ``.rollback()`` on all enclosing transactions is required for the transaction to logically clear, even though the DBAPI-level transaction has already been rolled back. The new behavior helps with situations such as the "ORM rollback test suite" pattern where the test suite rolls the transaction back within the ORM scope, but the test harness which seeks to control the scope of the transaction externally does not expect a new transaction to start implicitly. Fixes: #4712 Change-Id: Ibc6c8d981cff31594a5d26dd5203fd9cfcea1c74 --- lib/sqlalchemy/engine/base.py | 67 +++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 15 deletions(-) (limited to 'lib/sqlalchemy/engine') diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6467e91b9..7aee1a73b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -708,7 +708,10 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" - return self._root.__transaction is not None + return ( + self._root.__transaction is not None + and self._root.__transaction.is_active + ) def _begin_impl(self, transaction): assert not self.__branch_from @@ -726,7 +729,7 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _rollback_impl(self): + def _rollback_impl(self, deactivate_only=False): assert not self.__branch_from if self._has_events or self.engine._has_events: @@ -745,9 +748,6 @@ class Connection(Connectable): and self.connection._reset_agent is self.__transaction ): self.connection._reset_agent = None - self.__transaction = None - else: - self.__transaction = None def _commit_impl(self, autocommit=False): assert not self.__branch_from @@ -782,7 +782,18 @@ class Connection(Connectable): self.engine.dialect.do_savepoint(self, name) return name - def _rollback_to_savepoint_impl(self, name, context): + def _discard_transaction(self, trans): + if trans is self.__transaction: + if trans._is_root: + assert trans._parent is trans + self.__transaction = None + else: + assert trans._parent is not trans + self.__transaction = trans._parent + + def _rollback_to_savepoint_impl( + self, name, context, deactivate_only=False + ): assert not self.__branch_from if self._has_events or self.engine._has_events: @@ -790,7 +801,6 @@ class Connection(Connectable): if self._still_open_and_connection_is_valid: self.engine.dialect.do_rollback_to_savepoint(self, name) - self.__transaction = context def _release_savepoint_impl(self, name, context): assert not self.__branch_from @@ -1182,6 +1192,17 @@ class Connection(Connectable): e, util.text_type(statement), parameters, None, None ) + if self._root.__transaction and not self._root.__transaction.is_active: + raise exc.InvalidRequestError( + "This connection is on an inactive %stransaction. " + "Please rollback() fully before proceeding." + % ( + "savepoint " + if isinstance(self.__transaction, NestedTransaction) + else "" + ), + code="8s2a", + ) if context.compiled: context.pre_exec() @@ -1671,11 +1692,16 @@ class Transaction(object): single: thread safety; Transaction """ + _is_root = False + def __init__(self, connection, parent): self.connection = connection self._actual_parent = parent self.is_active = True + def _deactivate(self): + self.is_active = False + @property def _parent(self): return self._actual_parent or self @@ -1700,13 +1726,14 @@ class Transaction(object): """Roll back this :class:`.Transaction`. """ - if not self._parent.is_active: - return - self._do_rollback() - self.is_active = False + + if self._parent.is_active: + self._do_rollback() + self.is_active = False + self.connection._discard_transaction(self) def _do_rollback(self): - self._parent.rollback() + self._parent._deactivate() def commit(self): """Commit this :class:`.Transaction`.""" @@ -1734,13 +1761,19 @@ class Transaction(object): class RootTransaction(Transaction): + _is_root = True + def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) self.connection._begin_impl(self) - def _do_rollback(self): + def _deactivate(self): + self._do_rollback(deactivate_only=True) + self.is_active = False + + def _do_rollback(self, deactivate_only=False): if self.is_active: - self.connection._rollback_impl() + self.connection._rollback_impl(deactivate_only=deactivate_only) def _do_commit(self): if self.is_active: @@ -1761,7 +1794,11 @@ class NestedTransaction(Transaction): super(NestedTransaction, self).__init__(connection, parent) self._savepoint = self.connection._savepoint_impl() - def _do_rollback(self): + def _deactivate(self): + self._do_rollback(deactivate_only=True) + self.is_active = False + + def _do_rollback(self, deactivate_only=False): if self.is_active: self.connection._rollback_to_savepoint_impl( self._savepoint, self._parent -- cgit v1.2.1