from alembic.testing.fixtures import TestBase
from alembic.testing import config, eq_, assert_raises, assert_raises_message
from sqlalchemy import Table, MetaData, Column, String
from sqlalchemy.engine.reflection import Inspector

from alembic import migration
from alembic.util import CommandError

# Table definition shared by every test below.  Tests that need the table to
# physically exist call ``version_table.create(...)`` themselves.
version_table = Table(
    'version_table', MetaData(),
    Column('version_num', String(32), nullable=False),
)


def _stamp_step(from_, to_, is_upgrade, branch_presence_changed):
    """Build a ``migration.StampStep`` moving from *from_* to *to_*."""
    return migration.StampStep(
        from_, to_, is_upgrade, branch_presence_changed)


def _up(from_, to_, branch_presence_changed=False):
    """Return a stamp step built with the upgrade flag set (True)."""
    return _stamp_step(from_, to_, True, branch_presence_changed)


def _down(from_, to_, branch_presence_changed=False):
    """Return a stamp step built with the upgrade flag cleared (False)."""
    return _stamp_step(from_, to_, False, branch_presence_changed)


class TestMigrationContext(TestBase):
    """Configuration and current-revision tests for ``MigrationContext``.

    Every test runs inside a transaction that tearDown rolls back.
    """

    @classmethod
    def setup_class(cls):
        cls.bind = config.db

    def setUp(self):
        self.connection = self.bind.connect()
        self.transaction = self.connection.begin()

    def tearDown(self):
        # checkfirst: several tests never create the table at all.
        version_table.drop(self.connection, checkfirst=True)
        self.transaction.rollback()
        self.connection.close()

    def make_one(self, **kwargs):
        """Build a context via ``MigrationContext.configure(**kwargs)``."""
        return migration.MigrationContext.configure(**kwargs)

    def get_revision(self):
        """Return the lone ``version_num`` row of ``version_table``.

        Returns None when the table has no rows; fails through ``eq_`` when
        it holds more than one.
        """
        rows = self.connection.execute(version_table.select()).fetchall()
        if not rows:
            return None
        eq_(len(rows), 1)
        return rows[0]['version_num']

    def _context_with_heads_a_and_b(self):
        """Create ``version_table`` and stamp heads 'a' and 'b' into it.

        Returns the context bound to the test connection.
        """
        version_table.create(self.connection)
        ctx = self.make_one(connection=self.connection,
                            opts={'version_table': 'version_table'})
        maintainer = migration.HeadMaintainer(ctx, ())
        maintainer.update_to_step(_up(None, 'a', True))
        maintainer.update_to_step(_up(None, 'b', True))
        return ctx

    def test_config_default_version_table_name(self):
        ctx = self.make_one(dialect_name='sqlite')
        eq_(ctx._version.name, 'alembic_version')

    def test_config_explicit_version_table_name(self):
        ctx = self.make_one(dialect_name='sqlite',
                            opts={'version_table': 'explicit'})
        eq_(ctx._version.name, 'explicit')

    def test_config_explicit_version_table_schema(self):
        ctx = self.make_one(dialect_name='sqlite',
                            opts={'version_table_schema': 'explicit'})
        eq_(ctx._version.schema, 'explicit')

    def test_get_current_revision_doesnt_create_version_table(self):
        ctx = self.make_one(connection=self.connection,
                            opts={'version_table': 'version_table'})
        eq_(ctx.get_current_revision(), None)
        # Reading the revision must not have created the table as a
        # side effect.
        table_names = Inspector(self.connection).get_table_names()
        assert 'version_table' not in table_names

    def test_get_current_revision(self):
        ctx = self.make_one(connection=self.connection,
                            opts={'version_table': 'version_table'})
        version_table.create(self.connection)
        eq_(ctx.get_current_revision(), None)

        self.connection.execute(
            version_table.insert().values(version_num='revid'))
        eq_(ctx.get_current_revision(), 'revid')

    def test_get_current_revision_error_if_starting_rev_given_online(self):
        ctx = self.make_one(connection=self.connection,
                            opts={'starting_rev': 'boo'})
        assert_raises(CommandError, ctx.get_current_revision)

    def test_get_current_revision_offline(self):
        ctx = self.make_one(dialect_name='sqlite',
                            opts={'starting_rev': 'startrev',
                                  'as_sql': True})
        eq_(ctx.get_current_revision(), 'startrev')

    def test_get_current_revision_multiple_heads(self):
        ctx = self._context_with_heads_a_and_b()
        assert_raises_message(
            CommandError,
            "Version table 'version_table' has more than one head present; "
            "please use get_current_heads()",
            ctx.get_current_revision
        )

    def test_get_heads(self):
        ctx = self._context_with_heads_a_and_b()
        eq_(ctx.get_current_heads(), ('a', 'b'))

    def test_get_heads_offline(self):
        version_table.create(self.connection)
        ctx = self.make_one(connection=self.connection,
                            opts={'starting_rev': 'q',
                                  'version_table': 'version_table',
                                  'as_sql': True})
        eq_(ctx.get_current_heads(), ('q', ))


class UpdateRevTest(TestBase):
    """Tests for ``HeadMaintainer.update_to_step`` against a live table.

    No enclosing transaction is used here: setUp creates ``version_table``
    and tearDown drops it again.
    """

    @classmethod
    def setup_class(cls):
        cls.bind = config.db

    def setUp(self):
        self.connection = self.bind.connect()
        self.context = migration.MigrationContext.configure(
            connection=self.connection,
            opts={"version_table": "version_table"},
        )
        version_table.create(self.connection)
        self.updater = migration.HeadMaintainer(self.context, ())

    def tearDown(self):
        version_table.drop(self.connection, checkfirst=True)
        self.connection.close()

    def _apply(self, *steps):
        """Hand each step to the head maintainer, strictly in order."""
        for step in steps:
            self.updater.update_to_step(step)

    def _insert_twice(self, version_num):
        """Insert two identical rows, bypassing the head maintainer."""
        for _ in range(2):
            self.connection.execute(version_table.insert(),
                                    version_num=version_num)

    def _assert_heads(self, heads):
        # The database rows and the maintainer's in-memory set must agree.
        eq_(self.context.get_current_heads(), heads)
        eq_(self.updater.heads, set(heads))

    def test_update_none_to_single(self):
        self._apply(_up(None, 'a', True))
        self._assert_heads(('a',))

    def test_update_single_to_single(self):
        self._apply(
            _up(None, 'a', True),
            _up('a', 'b'),
        )
        self._assert_heads(('b',))

    def test_update_single_to_none(self):
        self._apply(
            _up(None, 'a', True),
            _down('a', None, True),
        )
        self._assert_heads(())

    def test_add_branches(self):
        self._apply(
            _up(None, 'a', True),
            _up('a', 'b'),
            _up(None, 'c', True),
        )
        self._assert_heads(('b', 'c'))

        self._apply(
            _up('c', 'd'),
            _up('d', 'e1'),
            _up('d', 'e2', True),
        )
        self._assert_heads(('b', 'e1', 'e2'))

    def test_teardown_branches(self):
        self._apply(
            _up(None, 'd1', True),
            _up(None, 'd2', True),
        )
        self._assert_heads(('d1', 'd2'))

        self._apply(_down('d1', 'c'))
        self._assert_heads(('c', 'd2'))

        self._apply(_down('d2', 'c', True))
        self._assert_heads(('c',))

        self._apply(_down('c', 'b'))
        self._assert_heads(('b',))

    def test_resolve_merges(self):
        self._apply(
            _up(None, 'a', True),
            _up('a', 'b'),
            _up('b', 'c1'),
            _up('b', 'c2', True),
            _up('c1', 'd1'),
            _up('c2', 'd2'),
        )
        self._assert_heads(('d1', 'd2'))

        # A single step whose source is a tuple merges both heads.
        self._apply(_up(('d1', 'd2'), 'e'))
        self._assert_heads(('e',))

    def test_unresolve_merges(self):
        self._apply(
            _up(None, 'e', True),
            _down('e', ('d1', 'd2')),
        )
        self._assert_heads(('d2', 'd1'))

        self._apply(_down('d2', 'c2'))
        self._assert_heads(('c2', 'd1'))

    def test_update_no_match(self):
        self._apply(_up(None, 'a', True))
        # 'x' is known in memory but has no row in the table.
        self.updater.heads.add('x')
        assert_raises_message(
            CommandError,
            "Online migration expected to match one row when updating "
            "'x' to 'b' in 'version_table'; 0 found",
            self.updater.update_to_step, _up('x', 'b')
        )

    def test_update_multi_match(self):
        self._insert_twice('a')
        self.updater.heads.add('a')
        assert_raises_message(
            CommandError,
            "Online migration expected to match one row when updating "
            "'a' to 'b' in 'version_table'; 2 found",
            self.updater.update_to_step, _up('a', 'b')
        )

    def test_delete_no_match(self):
        self._apply(_up(None, 'a', True))
        # 'x' is known in memory but has no row in the table.
        self.updater.heads.add('x')
        assert_raises_message(
            CommandError,
            "Online migration expected to match one row when "
            "deleting 'x' in 'version_table'; 0 found",
            self.updater.update_to_step, _down('x', None, True)
        )

    def test_delete_multi_match(self):
        self._insert_twice('a')
        self.updater.heads.add('a')
        assert_raises_message(
            CommandError,
            "Online migration expected to match one row when "
            "deleting 'a' in 'version_table'; 2 found",
            self.updater.update_to_step, _down('a', None, True)
        )