summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2022-10-29 16:05:43 +0000
committerOlly Cope <olly@ollycope.com>2022-10-29 16:05:43 +0000
commit5e66d76c03a4eadb01eb6e88b4480b303a3d9e9e (patch)
tree30795886930a445c2b8f5a0437c445397e1147f5
parent236b839acc6c7a20cb30cd84985f3ece725a6939 (diff)
downloadyoyo-5e66d76c03a4eadb01eb6e88b4480b303a3d9e9e.tar.gz
transaction handling: replace confusing rollback method with function argument
This code: with backend.transaction() as transaction: transaction.rollback() ... Looks like it rolls back the connection at the start of the block - but that's not what is actually happening. The code becomes much clearer with this change: with backend.transaction(rollback_on_exit=True): ...
-rw-r--r--yoyo/backends/base.py22
-rwxr-xr-xyoyo/migrations.py16
-rw-r--r--yoyo/tests/test_backends.py8
3 files changed, 17 insertions, 29 deletions
diff --git a/yoyo/backends/base.py b/yoyo/backends/base.py
index 4f95a66..1f3c167 100644
--- a/yoyo/backends/base.py
+++ b/yoyo/backends/base.py
@@ -45,9 +45,9 @@ class TransactionManager:
when the context manager block closes
"""
- def __init__(self, backend):
+ def __init__(self, backend, rollback_on_exit=False):
self.backend = backend
- self._rollback = False
+ self.rollback_on_exit = rollback_on_exit
def __enter__(self):
self._do_begin()
@@ -58,18 +58,11 @@ class TransactionManager:
self._do_rollback()
return None
- if self._rollback:
+ if self.rollback_on_exit:
self._do_rollback()
else:
self._do_commit()
- def rollback(self):
- """
- Flag that the transaction will be rolled back when the with statement
- exits
- """
- self._rollback = True
-
def _do_begin(self):
"""
Instruct the backend to begin a transaction
@@ -238,9 +231,8 @@ class DatabaseBackend:
table_name_quoted = self.quote_identifier(table_name)
sql = self.create_test_table_sql.format(table_name_quoted=table_name_quoted)
try:
- with self.transaction() as t:
+ with self.transaction(rollback_on_exit=True):
self.execute(sql)
- t.rollback()
except self.DatabaseError:
return False
@@ -263,12 +255,12 @@ class DatabaseBackend:
)
return [row[0] for row in cursor.fetchall()]
- def transaction(self):
+ def transaction(self, rollback_on_exit=False):
if not self._in_transaction:
- return TransactionManager(self)
+ return TransactionManager(self, rollback_on_exit=rollback_on_exit)
else:
- return SavepointTransactionManager(self)
+ return SavepointTransactionManager(self, rollback_on_exit=rollback_on_exit)
def cursor(self):
return self.connection.cursor()
diff --git a/yoyo/migrations.py b/yoyo/migrations.py
index 5450d0a..5b30049 100755
--- a/yoyo/migrations.py
+++ b/yoyo/migrations.py
@@ -312,16 +312,14 @@ class TransactionWrapper(StepBase):
return "<TransactionWrapper {!r}>".format(self.step)
def apply(self, backend, force=False, direction="apply"):
- with backend.transaction() as transaction:
- try:
+ try:
+ with backend.transaction():
getattr(self.step, direction)(backend, force)
- except backend.DatabaseError:
- if force or self.ignore_errors in (direction, "all"):
- logger.exception("Ignored error in %r", self.step)
- transaction.rollback()
- return
- else:
- raise
+ except backend.DatabaseError:
+ if force or self.ignore_errors in (direction, "all"):
+ logger.exception("Ignored error in %r", self.step)
+ else:
+ raise
def rollback(self, backend, force=False):
self.apply(backend, force, "rollback")
diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py
index 8d99009..e4b9727 100644
--- a/yoyo/tests/test_backends.py
+++ b/yoyo/tests/test_backends.py
@@ -44,11 +44,10 @@ class TestTransactionHandling(object):
with backend.transaction():
backend.execute("INSERT INTO yoyo_t values ('A')")
- with backend.transaction() as trans:
+ with backend.transaction(rollback_on_exit=True):
backend.execute("INSERT INTO yoyo_t values ('B')")
- trans.rollback()
- with backend.transaction() as trans:
+ with backend.transaction():
backend.execute("INSERT INTO yoyo_t values ('C')")
with backend.transaction():
@@ -95,12 +94,11 @@ class TestTransactionHandling(object):
if backend.has_transactional_ddl:
return
- with backend.transaction() as trans:
+ with backend.transaction(rollback_on_exit=True):
backend.execute("CREATE TABLE yoyo_a (id INT)") # implicit commit
backend.execute("INSERT INTO yoyo_a VALUES (1)")
backend.execute("CREATE TABLE yoyo_b (id INT)") # implicit commit
backend.execute("INSERT INTO yoyo_b VALUES (1)")
- trans.rollback()
count_a = backend.execute("SELECT COUNT(1) FROM yoyo_a").fetchall()[0][0]
assert count_a == 1