diff options
author | Olly Cope <olly@ollycope.com> | 2022-10-29 16:05:43 +0000 |
---|---|---|
committer | Olly Cope <olly@ollycope.com> | 2022-10-29 16:05:43 +0000 |
commit | 5e66d76c03a4eadb01eb6e88b4480b303a3d9e9e (patch) | |
tree | 30795886930a445c2b8f5a0437c445397e1147f5 | |
parent | 236b839acc6c7a20cb30cd84985f3ece725a6939 (diff) | |
download | yoyo-5e66d76c03a4eadb01eb6e88b4480b303a3d9e9e.tar.gz |
transaction handling: replace confusing rollback method with function argument
This code:
with backend.transaction() as transaction:
transaction.rollback()
...
Looks like it rolls back the connection at the start of the block - but
that's not what is actually happening. The code becomes much clearer with this
change:
with backend.transaction(rollback_on_exit=True):
...
-rw-r--r-- | yoyo/backends/base.py | 22 | ||||
-rwxr-xr-x | yoyo/migrations.py | 16 | ||||
-rw-r--r-- | yoyo/tests/test_backends.py | 8 |
3 files changed, 17 insertions, 29 deletions
diff --git a/yoyo/backends/base.py b/yoyo/backends/base.py index 4f95a66..1f3c167 100644 --- a/yoyo/backends/base.py +++ b/yoyo/backends/base.py @@ -45,9 +45,9 @@ class TransactionManager: when the context manager block closes """ - def __init__(self, backend): + def __init__(self, backend, rollback_on_exit=False): self.backend = backend - self._rollback = False + self.rollback_on_exit = rollback_on_exit def __enter__(self): self._do_begin() @@ -58,18 +58,11 @@ class TransactionManager: self._do_rollback() return None - if self._rollback: + if self.rollback_on_exit: self._do_rollback() else: self._do_commit() - def rollback(self): - """ - Flag that the transaction will be rolled back when the with statement - exits - """ - self._rollback = True - def _do_begin(self): """ Instruct the backend to begin a transaction @@ -238,9 +231,8 @@ class DatabaseBackend: table_name_quoted = self.quote_identifier(table_name) sql = self.create_test_table_sql.format(table_name_quoted=table_name_quoted) try: - with self.transaction() as t: + with self.transaction(rollback_on_exit=True): self.execute(sql) - t.rollback() except self.DatabaseError: return False @@ -263,12 +255,12 @@ class DatabaseBackend: ) return [row[0] for row in cursor.fetchall()] - def transaction(self): + def transaction(self, rollback_on_exit=False): if not self._in_transaction: - return TransactionManager(self) + return TransactionManager(self, rollback_on_exit=rollback_on_exit) else: - return SavepointTransactionManager(self) + return SavepointTransactionManager(self, rollback_on_exit=rollback_on_exit) def cursor(self): return self.connection.cursor() diff --git a/yoyo/migrations.py b/yoyo/migrations.py index 5450d0a..5b30049 100755 --- a/yoyo/migrations.py +++ b/yoyo/migrations.py @@ -312,16 +312,14 @@ class TransactionWrapper(StepBase): return "<TransactionWrapper {!r}>".format(self.step) def apply(self, backend, force=False, direction="apply"): - with backend.transaction() as transaction: - try: + try: + with backend.transaction(): getattr(self.step, direction)(backend, force) - except backend.DatabaseError: - if force or self.ignore_errors in (direction, "all"): - logger.exception("Ignored error in %r", self.step) - transaction.rollback() - return - else: - raise + except backend.DatabaseError: + if force or self.ignore_errors in (direction, "all"): + logger.exception("Ignored error in %r", self.step) + else: + raise def rollback(self, backend, force=False): self.apply(backend, force, "rollback") diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py index 8d99009..e4b9727 100644 --- a/yoyo/tests/test_backends.py +++ b/yoyo/tests/test_backends.py @@ -44,11 +44,10 @@ class TestTransactionHandling(object): with backend.transaction(): backend.execute("INSERT INTO yoyo_t values ('A')") - with backend.transaction() as trans: + with backend.transaction(rollback_on_exit=True): backend.execute("INSERT INTO yoyo_t values ('B')") - trans.rollback() - with backend.transaction() as trans: + with backend.transaction(): backend.execute("INSERT INTO yoyo_t values ('C')") with backend.transaction(): @@ -95,12 +94,11 @@ class TestTransactionHandling(object): if backend.has_transactional_ddl: return - with backend.transaction() as trans: + with backend.transaction(rollback_on_exit=True): backend.execute("CREATE TABLE yoyo_a (id INT)") # implicit commit backend.execute("INSERT INTO yoyo_a VALUES (1)") backend.execute("CREATE TABLE yoyo_b (id INT)") # implicit commit backend.execute("INSERT INTO yoyo_b VALUES (1)") - trans.rollback() count_a = backend.execute("SELECT COUNT(1) FROM yoyo_a").fetchall()[0][0] assert count_a == 1 |