diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 67 |
1 files changed, 52 insertions, 15 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6467e91b9..7aee1a73b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -708,7 +708,10 @@ class Connection(Connectable): def in_transaction(self): """Return True if a transaction is in progress.""" - return self._root.__transaction is not None + return ( + self._root.__transaction is not None + and self._root.__transaction.is_active + ) def _begin_impl(self, transaction): assert not self.__branch_from @@ -726,7 +729,7 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - def _rollback_impl(self): + def _rollback_impl(self, deactivate_only=False): assert not self.__branch_from if self._has_events or self.engine._has_events: @@ -745,9 +748,6 @@ class Connection(Connectable): and self.connection._reset_agent is self.__transaction ): self.connection._reset_agent = None - self.__transaction = None - else: - self.__transaction = None def _commit_impl(self, autocommit=False): assert not self.__branch_from @@ -782,7 +782,18 @@ class Connection(Connectable): self.engine.dialect.do_savepoint(self, name) return name - def _rollback_to_savepoint_impl(self, name, context): + def _discard_transaction(self, trans): + if trans is self.__transaction: + if trans._is_root: + assert trans._parent is trans + self.__transaction = None + else: + assert trans._parent is not trans + self.__transaction = trans._parent + + def _rollback_to_savepoint_impl( + self, name, context, deactivate_only=False + ): assert not self.__branch_from if self._has_events or self.engine._has_events: @@ -790,7 +801,6 @@ class Connection(Connectable): if self._still_open_and_connection_is_valid: self.engine.dialect.do_rollback_to_savepoint(self, name) - self.__transaction = context def _release_savepoint_impl(self, name, context): assert not self.__branch_from @@ -1182,6 +1192,17 @@ class Connection(Connectable): e, util.text_type(statement), parameters, None, None ) + if self._root.__transaction and not self._root.__transaction.is_active: + raise exc.InvalidRequestError( + "This connection is on an inactive %stransaction. " + "Please rollback() fully before proceeding." + % ( + "savepoint " + if isinstance(self.__transaction, NestedTransaction) + else "" + ), + code="8s2a", + ) if context.compiled: context.pre_exec() @@ -1671,11 +1692,16 @@ class Transaction(object): single: thread safety; Transaction """ + _is_root = False + def __init__(self, connection, parent): self.connection = connection self._actual_parent = parent self.is_active = True + def _deactivate(self): + self.is_active = False + @property def _parent(self): return self._actual_parent or self @@ -1700,13 +1726,14 @@ class Transaction(object): """Roll back this :class:`.Transaction`. """ - if not self._parent.is_active: - return - self._do_rollback() - self.is_active = False + + if self._parent.is_active: + self._do_rollback() + self.is_active = False + self.connection._discard_transaction(self) def _do_rollback(self): - self._parent.rollback() + self._parent._deactivate() def commit(self): """Commit this :class:`.Transaction`.""" @@ -1734,13 +1761,19 @@ class Transaction(object): class RootTransaction(Transaction): + _is_root = True + def __init__(self, connection): super(RootTransaction, self).__init__(connection, None) self.connection._begin_impl(self) - def _do_rollback(self): + def _deactivate(self): + self._do_rollback(deactivate_only=True) + self.is_active = False + + def _do_rollback(self, deactivate_only=False): if self.is_active: - self.connection._rollback_impl() + self.connection._rollback_impl(deactivate_only=deactivate_only) def _do_commit(self): if self.is_active: @@ -1761,7 +1794,11 @@ class NestedTransaction(Transaction): super(NestedTransaction, self).__init__(connection, parent) self._savepoint = self.connection._savepoint_impl() - def _do_rollback(self): + def _deactivate(self): + self._do_rollback(deactivate_only=True) + self.is_active = False + + def _do_rollback(self, deactivate_only=False): if self.is_active: self.connection._rollback_to_savepoint_impl( self._savepoint, self._parent |
