diff options
-rw-r--r-- | alembic/migration.py | 22 | ||||
-rw-r--r-- | alembic/revision.py | 153 | ||||
-rw-r--r-- | alembic/script.py | 22 | ||||
-rw-r--r-- | tests/test_revision.py | 157 |
4 files changed, 305 insertions, 49 deletions
diff --git a/alembic/migration.py b/alembic/migration.py index 7d854ba..eefb1d1 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -553,7 +553,7 @@ class RevisionStep(MigrationStep): @property def from_revisions(self): if self.is_upgrade: - return self.revision._down_revision_tuple + return self.revision._all_down_revisions else: return (self.revision.revision, ) @@ -562,11 +562,11 @@ class RevisionStep(MigrationStep): if self.is_upgrade: return (self.revision.revision, ) else: - return self.revision._down_revision_tuple + return self.revision._all_down_revisions @property def _has_scalar_down_revision(self): - return len(self.revision._down_revision_tuple) == 1 + return len(self.revision._all_down_revisions) == 1 def should_delete_branch(self, heads): if not self.is_downgrade: @@ -575,7 +575,7 @@ class RevisionStep(MigrationStep): if self.revision.revision not in heads: return False - downrevs = self.revision._down_revision_tuple + downrevs = self.revision._all_down_revisions if not downrevs: # is a base return True @@ -587,7 +587,7 @@ class RevisionStep(MigrationStep): descendants = set( r.revision for r in self.revision_map._get_descendant_nodes( - self.revision_map.get_revisions(downrev.nextrev), + self.revision_map.get_revisions(downrev._all_nextrev), check=False ) ) @@ -606,7 +606,7 @@ class RevisionStep(MigrationStep): # TODO: this doesn't work; make sure tests are here to ensure # this fails - #if len(downrev.nextrev.intersection(heads).difference( + #if len(downrev._all_nextrev.intersection(heads).difference( # [self.revision.revision])): return True @@ -662,7 +662,7 @@ class RevisionStep(MigrationStep): if not self.is_upgrade: return False - downrevs = self.revision._down_revision_tuple + downrevs = self.revision._all_down_revisions if not downrevs: # is a base @@ -680,7 +680,7 @@ class RevisionStep(MigrationStep): if not self.is_upgrade: return False - downrevs = self.revision._down_revision_tuple + downrevs = self.revision._all_down_revisions if len(downrevs) > 1 and \ len(heads.intersection(downrevs)) > 1: @@ -692,7 +692,7 @@ class RevisionStep(MigrationStep): if not self.is_downgrade: return False - downrevs = self.revision._down_revision_tuple + downrevs = self.revision._all_down_revisions if self.revision.revision in heads and len(downrevs) > 1: return True @@ -701,12 +701,12 @@ class RevisionStep(MigrationStep): def update_version_num(self, heads): if not self._has_scalar_down_revision: - downrev = heads.intersection(self.revision._down_revision_tuple) + downrev = heads.intersection(self.revision._all_down_revisions) assert len(downrev) == 1, \ "Can't do an UPDATE because downrevision is ambiguous" down_revision = list(downrev)[0] else: - down_revision = self.revision.down_revision + down_revision = self.revision._all_down_revisions[0] if self.is_upgrade: return down_revision, self.revision.revision diff --git a/alembic/revision.py b/alembic/revision.py index 071452f..0ab52ab 100644 --- a/alembic/revision.py +++ b/alembic/revision.py @@ -82,6 +82,16 @@ class RevisionMap(object): return self.bases @util.memoized_property + def _real_bases(self): + """All "real" base revisions as strings. + + :return: a tuple of string revision numbers. + + """ + self._revision_map + return self._real_bases + + @util.memoized_property def _revision_map(self): """memoized attribute, initializes the revision map from the initial collection. @@ -91,6 +101,7 @@ class RevisionMap(object): heads = sqlautil.OrderedSet() self.bases = () + self._real_bases = () has_branch_labels = set() for revision in self._generator(): @@ -104,14 +115,16 @@ class RevisionMap(object): heads.add(revision.revision) if revision.is_base: self.bases += (revision.revision, ) + if revision._is_real_base: + self._real_bases += (revision.revision, ) for rev in map_.values(): - for downrev in rev._down_revision_tuple: + for downrev in rev._all_down_revisions: if downrev not in map_: util.warn("Revision %s referenced from %s is not present" - % (rev.down_revision, rev)) + % (downrev, rev)) down_revision = map_[downrev] - down_revision.add_nextrev(rev.revision) + down_revision.add_nextrev(rev) heads.discard(downrev) map_[None] = map_[()] = None @@ -133,12 +146,13 @@ class RevisionMap(object): ) map_[branch_label] = revision revision.branch_labels.update(revision.branch_labels) - for node in self._get_descendant_nodes([revision], map_): + for node in self._get_descendant_nodes( + [revision], map_, include_dependencies=False): node.branch_labels.update(revision.branch_labels) parent = node while parent and \ - not parent.is_branch_point and not parent.is_merge_point: + not parent._is_real_branch_point and not parent.is_merge_point: parent.branch_labels.update(revision.branch_labels) if parent.down_revision: @@ -164,18 +178,20 @@ class RevisionMap(object): self._add_branches(revision, map_) if revision.is_base: self.bases += (revision.revision, ) - for downrev in revision._down_revision_tuple: + if revision._is_real_base: + self._real_bases += (revision.revision, ) + for downrev in revision._all_down_revisions: if downrev not in map_: util.warn( "Revision %s referenced from %s is not present" - % (revision.down_revision, revision) + % (downrev, revision) ) - map_[downrev].add_nextrev(revision.revision) - if revision.is_head: + map_[downrev].add_nextrev(revision) + if revision._is_real_head: self.heads = tuple( head for head in self.heads if head not in - set(revision._down_revision_tuple).union([revision.revision]) + set(revision._all_down_revisions).union([revision.revision]) ) + (revision.revision,) def get_current_head(self, branch_label=None): @@ -334,8 +350,12 @@ class RevisionMap(object): ] return bool( - set(self._get_descendant_nodes([target])) - .union(self._get_ancestor_nodes([target])) + set(self._get_descendant_nodes([target], + include_dependencies=False + )) + .union(self._get_ancestor_nodes([target], + include_dependencies=False + )) .intersection(test_against_revs) ) @@ -424,16 +444,28 @@ class RevisionMap(object): return self._iterate_revisions( upper, lower, inclusive=inclusive, implicit_base=implicit_base) - def _get_descendant_nodes(self, targets, map_=None, check=False): + def _get_descendant_nodes( + self, targets, map_=None, check=False, include_dependencies=True): + + if include_dependencies: + fn = lambda rev: rev._all_nextrev + else: + fn = lambda rev: rev.nextrev + return self._iterate_related_revisions( - lambda rev: rev.nextrev, - targets, map_=map_, check=check + fn, targets, map_=map_, check=check ) - def _get_ancestor_nodes(self, targets, map_=None, check=False): + def _get_ancestor_nodes( + self, targets, map_=None, check=False, include_dependencies=True): + + if include_dependencies: + fn = lambda rev: rev._all_down_revisions + else: + fn = lambda rev: rev._versioned_down_revisions + return self._iterate_related_revisions( - lambda rev: rev._down_revision_tuple, - targets, map_=map_, check=check + fn, targets, map_=map_, check=check ) def _iterate_related_revisions(self, fn, targets, map_, check=False): @@ -494,17 +526,17 @@ class RevisionMap(object): difference(lower_ancestors).\ difference(lower_descendants) for rev in candidate_lowers: - for downrev in rev._down_revision_tuple: + for downrev in rev._all_down_revisions: if self._revision_map[downrev] in candidate_lowers: break else: base_lowers.add(rev) lowers = base_lowers.union(requested_lowers) elif implicit_base: - base_lowers = set(self.get_revisions(self.bases)) + base_lowers = set(self.get_revisions(self._real_bases)) lowers = base_lowers.union(requested_lowers) elif not requested_lowers: - lowers = set(self.get_revisions(self.bases)) + lowers = set(self.get_revisions(self._real_bases)) else: lowers = requested_lowers @@ -523,12 +555,12 @@ class RevisionMap(object): branch_todo = set( rev for rev in (self._revision_map[rev] for rev in total_space) - if rev.is_branch_point and - len(total_space.intersection(rev.nextrev)) > 1 + if rev._is_real_branch_point and + len(total_space.intersection(rev._all_nextrev)) > 1 ) # it's not possible for any "uppers" to be in branch_todo, - # because the .nextrev of those nodes is not in total_space + # because the ._all_nextrev of those nodes is not in total_space #assert not branch_todo.intersection(uppers) todo = collections.deque( @@ -541,11 +573,17 @@ class RevisionMap(object): # descendants left in the queue if not todo: todo.extendleft( - rev for rev in branch_todo - if not rev.nextrev.intersection(total_space) + sorted( + ( + rev for rev in branch_todo + if not rev._all_nextrev.intersection(total_space) + ), + # favor "revisioned" branch points before + # dependent ones + key=lambda rev: 0 if rev.is_branch_point else 1 + ) ) branch_todo.difference_update(todo) - # iterate nodes that are in the immediate todo while todo: rev = todo.popleft() @@ -555,7 +593,7 @@ class RevisionMap(object): # don't consume any actual branch nodes todo.extendleft([ self._revision_map[downrev] - for downrev in reversed(rev._down_revision_tuple) + for downrev in reversed(rev._all_down_revisions) if self._revision_map[downrev] not in branch_todo and downrev in total_space]) @@ -577,28 +615,56 @@ class Revision(object): """ nextrev = frozenset() + """following revisions, based on down_revision only.""" + + _all_nextrev = frozenset() revision = None """The string revision number.""" down_revision = None - """The ``down_revision`` identifier(s) within the migration script.""" + """The ``down_revision`` identifier(s) within the migration script. + + Note that the total set of "down" revisions is + down_revision + dependencies. + + """ + + dependencies = None + """Additional revisions which this revision is dependent on. + + From a migration standpoint, these dependencies are added to the + down_revision to form the full iteration. However, the separation + of down_revision from "dependencies" is to assist in navigating + a history that contains many branches, typically a multi-root scenario. + + """ branch_labels = None """Optional string/tuple of symbolic names to apply to this revision's branch""" - def __init__(self, revision, down_revision, branch_labels=None): + def __init__( + self, revision, down_revision, + dependencies=None, branch_labels=None): self.revision = revision self.down_revision = tuple_rev_as_scalar(down_revision) + self.dependencies = tuple_rev_as_scalar(dependencies) self._orig_branch_labels = util.to_tuple(branch_labels, default=()) self.branch_labels = set(self._orig_branch_labels) - def add_nextrev(self, rev): - self.nextrev = self.nextrev.union([rev]) + def add_nextrev(self, revision): + self._all_nextrev = self._all_nextrev.union([revision.revision]) + if self.revision in revision._versioned_down_revisions: + self.nextrev = self.nextrev.union([revision.revision]) + + @property + def _all_down_revisions(self): + return util.to_tuple(self.down_revision, default=()) + \ + util.to_tuple(self.dependencies, default=()) @property - def _down_revision_tuple(self): + def _versioned_down_revisions(self): return util.to_tuple(self.down_revision, default=()) @property @@ -613,12 +679,23 @@ class Revision(object): return not bool(self.nextrev) @property + def _is_real_head(self): + return not bool(self._all_nextrev) + + @property def is_base(self): """Return True if this :class:`.Revision` is a 'base' revision.""" return self.down_revision is None @property + def _is_real_base(self): + """Return True if this :class:`.Revision` is a "real" base revision, + e.g. that it has no dependencies either.""" + + return self.down_revision is None and self.dependencies is None + + @property def is_branch_point(self): """Return True if this :class:`.Script` is a branch point. @@ -631,10 +708,18 @@ class Revision(object): return len(self.nextrev) > 1 @property + def _is_real_branch_point(self): + """Return True if this :class:`.Script` is a 'real' branch point, + taking into account dependencies as well. + + """ + return len(self._all_nextrev) > 1 + + @property def is_merge_point(self): """Return True if this :class:`.Script` is a merge point.""" - return len(self._down_revision_tuple) > 1 + return len(self._versioned_down_revisions) > 1 def tuple_rev_as_scalar(rev): diff --git a/alembic/script.py b/alembic/script.py index 1835605..2147652 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -483,7 +483,10 @@ class Script(revision.Revision): rev_id, module.down_revision, branch_labels=util.to_tuple( - getattr(module, 'branch_labels', None), default=())) + getattr(module, 'branch_labels', None), default=()), + dependencies=util.to_tuple( + getattr(module, 'depends_on', None), default=()) + ) module = None """The Python module representing the actual script itself.""" @@ -522,6 +525,10 @@ class Script(revision.Revision): else: entry += "Parent: %s\n" % (self._format_down_revision(), ) + if self.dependencies: + entry += "Depends on: %s\n" % ( + util.format_as_comma(self.dependencies)) + if self.is_branch_point: entry += "Branches into: %s\n" % ( util.format_as_comma(self.nextrev)) @@ -554,8 +561,15 @@ class Script(revision.Revision): include_parents=False, tree_indicators=True): text = self.revision if include_parents: - text = "%s -> %s" % ( - self._format_down_revision(), text) + if self.dependencies: + text = "%s (%s) -> %s" % ( + self._format_down_revision(), + util.format_as_comma(self.dependencies), + text + ) + else: + text = "%s -> %s" % ( + self._format_down_revision(), text) if include_branches and self.branch_labels: text += " (%s)" % util.format_as_comma(self.branch_labels) if tree_indicators: @@ -584,7 +598,7 @@ class Script(revision.Revision): if not self.down_revision: return "<base>" else: - return util.format_as_comma(self._down_revision_tuple) + return util.format_as_comma(self._versioned_down_revisions) @classmethod def _from_path(cls, scriptdir, path): diff --git a/tests/test_revision.py b/tests/test_revision.py index 10fae91..e9e8935 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -715,3 +715,160 @@ class MultipleBaseTest(DownIterateTest): ['b3', 'a3', 'b2', 'a2', 'base2'], inclusive=False, implicit_base=True ) + + +class MultipleBaseCrossDependencyTestOne(DownIterateTest): + def setUp(self): + self.map = RevisionMap( + lambda: [ + Revision('base1', (), branch_labels='b_1'), + Revision('a1a', ('base1',)), + Revision('a1b', ('base1',)), + Revision('b1a', ('a1a',)), + Revision('b1b', ('a1b', ), dependencies='a3'), + + Revision('base2', (), branch_labels='b_2'), + Revision('a2', ('base2',)), + Revision('b2', ('a2',)), + Revision('c2', ('b2', ), dependencies='a3'), + Revision('d2', ('c2',)), + + Revision('base3', (), branch_labels='b_3'), + Revision('a3', ('base3',)), + Revision('b3', ('a3',)), + ] + ) + + def test_what_are_the_heads(self): + eq_(self.map.heads, ("b1a", "b1b", "d2", "b3")) + + def test_heads_to_base(self): + self._assert_iteration( + "heads", "base", + [ + + 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2', + 'b3', 'a3', 'base3', + 'base1' + ] + ) + + def test_we_need_head2(self): + # the 2 branch relies on the 3 branch + self._assert_iteration( + "b_2@head", "base", + ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3'] + ) + + def test_we_need_head3(self): + # the 3 branch can be upgraded alone. + self._assert_iteration( + "b_3@head", "base", + ['b3', 'a3', 'base3'] + ) + + def test_we_need_head1(self): + # the 1 branch relies on the 3 branch + self._assert_iteration( + "b1b@head", "base", + ['b1b', 'a1b', 'base1', 'a3', 'base3'] + ) + + def test_we_need_base2(self): + # consider a downgrade to b_2@base - we + # want to run through all the "2"s alone, and we're done. + self._assert_iteration( + "heads", "b_2@base", + ['d2', 'c2', 'b2', 'a2', 'base2'] + ) + + def test_we_need_base3(self): + # consider a downgrade to b_3@base - due to the a3 dependency, we + # need to downgrade everything dependent on a3 + # as well, which means b1b and c2. Then we can downgrade + # the 3s. + self._assert_iteration( + "heads", "b_3@base", + ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3'] + ) + + +class MultipleBaseCrossDependencyTestTwo(DownIterateTest): + def setUp(self): + self.map = RevisionMap( + lambda: [ + Revision('base1', (), branch_labels='b_1'), + Revision('a1', 'base1'), + Revision('b1', 'a1'), + Revision('c1', 'b1'), + + Revision('base2', (), dependencies='base1', branch_labels='b_2'), + Revision('a2', 'base2'), + Revision('b2', 'a2'), + Revision('c2', 'b2'), + Revision('d2', 'c2'), + + Revision('base3', (), branch_labels='b_3'), + Revision('a3', 'base3'), + Revision('b3', 'a3'), + Revision('c3', 'b3', dependencies='b2'), + Revision('d3', 'c3'), + ] + ) + + def test_what_are_the_heads(self): + eq_(self.map.heads, ("c1", "d2", "d3")) + + def test_heads_to_base(self): + self._assert_iteration( + "heads", "base", + [ + 'c1', 'b1', 'a1', + 'd2', 'c2', + 'd3', 'c3', 'b3', 'a3', 'base3', + 'b2', 'a2', 'base2', + 'base1' + ] + ) + + def test_we_need_head2(self): + self._assert_iteration( + "b_2@head", "base", + ['d2', 'c2', 'b2', 'a2', 'base2', 'base1'] + ) + + def test_we_need_head3(self): + self._assert_iteration( + "b_3@head", "base", + ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1'] + ) + + def test_we_need_head1(self): + self._assert_iteration( + "b_1@head", "base", + ['c1', 'b1', 'a1', 'base1'] + ) + + def test_we_need_base1(self): + self._assert_iteration( + "heads", "b_1@base", + [ + 'c1', 'b1', 'a1', + 'd2', 'c2', + 'd3', 'c3', 'b2', 'a2', 'base2', + 'base1' + ] + ) + + def test_we_need_base2(self): + self._assert_iteration( + "heads", "b_2@base", + ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2'] + ) + + def test_we_need_base3(self): + self._assert_iteration( + "heads", "b_3@base", + ['d3', 'c3', 'b3', 'a3', 'base3'] + ) + |