summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--alembic/migration.py22
-rw-r--r--alembic/revision.py153
-rw-r--r--alembic/script.py22
-rw-r--r--tests/test_revision.py157
4 files changed, 305 insertions, 49 deletions
diff --git a/alembic/migration.py b/alembic/migration.py
index 7d854ba..eefb1d1 100644
--- a/alembic/migration.py
+++ b/alembic/migration.py
@@ -553,7 +553,7 @@ class RevisionStep(MigrationStep):
@property
def from_revisions(self):
if self.is_upgrade:
- return self.revision._down_revision_tuple
+ return self.revision._all_down_revisions
else:
return (self.revision.revision, )
@@ -562,11 +562,11 @@ class RevisionStep(MigrationStep):
if self.is_upgrade:
return (self.revision.revision, )
else:
- return self.revision._down_revision_tuple
+ return self.revision._all_down_revisions
@property
def _has_scalar_down_revision(self):
- return len(self.revision._down_revision_tuple) == 1
+ return len(self.revision._all_down_revisions) == 1
def should_delete_branch(self, heads):
if not self.is_downgrade:
@@ -575,7 +575,7 @@ class RevisionStep(MigrationStep):
if self.revision.revision not in heads:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if not downrevs:
# is a base
return True
@@ -587,7 +587,7 @@ class RevisionStep(MigrationStep):
descendants = set(
r.revision for r in self.revision_map._get_descendant_nodes(
- self.revision_map.get_revisions(downrev.nextrev),
+ self.revision_map.get_revisions(downrev._all_nextrev),
check=False
)
)
@@ -606,7 +606,7 @@ class RevisionStep(MigrationStep):
# TODO: this doesn't work; make sure tests are here to ensure
# this fails
- #if len(downrev.nextrev.intersection(heads).difference(
+ #if len(downrev._all_nextrev.intersection(heads).difference(
# [self.revision.revision])):
return True
@@ -662,7 +662,7 @@ class RevisionStep(MigrationStep):
if not self.is_upgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if not downrevs:
# is a base
@@ -680,7 +680,7 @@ class RevisionStep(MigrationStep):
if not self.is_upgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if len(downrevs) > 1 and \
len(heads.intersection(downrevs)) > 1:
@@ -692,7 +692,7 @@ class RevisionStep(MigrationStep):
if not self.is_downgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if self.revision.revision in heads and len(downrevs) > 1:
return True
@@ -701,12 +701,12 @@ class RevisionStep(MigrationStep):
def update_version_num(self, heads):
if not self._has_scalar_down_revision:
- downrev = heads.intersection(self.revision._down_revision_tuple)
+ downrev = heads.intersection(self.revision._all_down_revisions)
assert len(downrev) == 1, \
"Can't do an UPDATE because downrevision is ambiguous"
down_revision = list(downrev)[0]
else:
- down_revision = self.revision.down_revision
+ down_revision = self.revision._all_down_revisions[0]
if self.is_upgrade:
return down_revision, self.revision.revision
diff --git a/alembic/revision.py b/alembic/revision.py
index 071452f..0ab52ab 100644
--- a/alembic/revision.py
+++ b/alembic/revision.py
@@ -82,6 +82,16 @@ class RevisionMap(object):
return self.bases
@util.memoized_property
+ def _real_bases(self):
+ """All "real" base revisions as strings.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self._real_bases
+
+ @util.memoized_property
def _revision_map(self):
"""memoized attribute, initializes the revision map from the
initial collection.
@@ -91,6 +101,7 @@ class RevisionMap(object):
heads = sqlautil.OrderedSet()
self.bases = ()
+ self._real_bases = ()
has_branch_labels = set()
for revision in self._generator():
@@ -104,14 +115,16 @@ class RevisionMap(object):
heads.add(revision.revision)
if revision.is_base:
self.bases += (revision.revision, )
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
for rev in map_.values():
- for downrev in rev._down_revision_tuple:
+ for downrev in rev._all_down_revisions:
if downrev not in map_:
util.warn("Revision %s referenced from %s is not present"
- % (rev.down_revision, rev))
+ % (downrev, rev))
down_revision = map_[downrev]
- down_revision.add_nextrev(rev.revision)
+ down_revision.add_nextrev(rev)
heads.discard(downrev)
map_[None] = map_[()] = None
@@ -133,12 +146,13 @@ class RevisionMap(object):
)
map_[branch_label] = revision
revision.branch_labels.update(revision.branch_labels)
- for node in self._get_descendant_nodes([revision], map_):
+ for node in self._get_descendant_nodes(
+ [revision], map_, include_dependencies=False):
node.branch_labels.update(revision.branch_labels)
parent = node
while parent and \
- not parent.is_branch_point and not parent.is_merge_point:
+ not parent._is_real_branch_point and not parent.is_merge_point:
parent.branch_labels.update(revision.branch_labels)
if parent.down_revision:
@@ -164,18 +178,20 @@ class RevisionMap(object):
self._add_branches(revision, map_)
if revision.is_base:
self.bases += (revision.revision, )
- for downrev in revision._down_revision_tuple:
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
+ for downrev in revision._all_down_revisions:
if downrev not in map_:
util.warn(
"Revision %s referenced from %s is not present"
- % (revision.down_revision, revision)
+ % (downrev, revision)
)
- map_[downrev].add_nextrev(revision.revision)
- if revision.is_head:
+ map_[downrev].add_nextrev(revision)
+ if revision._is_real_head:
self.heads = tuple(
head for head in self.heads
if head not in
- set(revision._down_revision_tuple).union([revision.revision])
+ set(revision._all_down_revisions).union([revision.revision])
) + (revision.revision,)
def get_current_head(self, branch_label=None):
@@ -334,8 +350,12 @@ class RevisionMap(object):
]
return bool(
- set(self._get_descendant_nodes([target]))
- .union(self._get_ancestor_nodes([target]))
+ set(self._get_descendant_nodes([target],
+ include_dependencies=False
+ ))
+ .union(self._get_ancestor_nodes([target],
+ include_dependencies=False
+ ))
.intersection(test_against_revs)
)
@@ -424,16 +444,28 @@ class RevisionMap(object):
return self._iterate_revisions(
upper, lower, inclusive=inclusive, implicit_base=implicit_base)
- def _get_descendant_nodes(self, targets, map_=None, check=False):
+ def _get_descendant_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_nextrev
+ else:
+ fn = lambda rev: rev.nextrev
+
return self._iterate_related_revisions(
- lambda rev: rev.nextrev,
- targets, map_=map_, check=check
+ fn, targets, map_=map_, check=check
)
- def _get_ancestor_nodes(self, targets, map_=None, check=False):
+ def _get_ancestor_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_down_revisions
+ else:
+ fn = lambda rev: rev._versioned_down_revisions
+
return self._iterate_related_revisions(
- lambda rev: rev._down_revision_tuple,
- targets, map_=map_, check=check
+ fn, targets, map_=map_, check=check
)
def _iterate_related_revisions(self, fn, targets, map_, check=False):
@@ -494,17 +526,17 @@ class RevisionMap(object):
difference(lower_ancestors).\
difference(lower_descendants)
for rev in candidate_lowers:
- for downrev in rev._down_revision_tuple:
+ for downrev in rev._all_down_revisions:
if self._revision_map[downrev] in candidate_lowers:
break
else:
base_lowers.add(rev)
lowers = base_lowers.union(requested_lowers)
elif implicit_base:
- base_lowers = set(self.get_revisions(self.bases))
+ base_lowers = set(self.get_revisions(self._real_bases))
lowers = base_lowers.union(requested_lowers)
elif not requested_lowers:
- lowers = set(self.get_revisions(self.bases))
+ lowers = set(self.get_revisions(self._real_bases))
else:
lowers = requested_lowers
@@ -523,12 +555,12 @@ class RevisionMap(object):
branch_todo = set(
rev for rev in
(self._revision_map[rev] for rev in total_space)
- if rev.is_branch_point and
- len(total_space.intersection(rev.nextrev)) > 1
+ if rev._is_real_branch_point and
+ len(total_space.intersection(rev._all_nextrev)) > 1
)
# it's not possible for any "uppers" to be in branch_todo,
- # because the .nextrev of those nodes is not in total_space
+ # because the ._all_nextrev of those nodes is not in total_space
#assert not branch_todo.intersection(uppers)
todo = collections.deque(
@@ -541,11 +573,17 @@ class RevisionMap(object):
# descendants left in the queue
if not todo:
todo.extendleft(
- rev for rev in branch_todo
- if not rev.nextrev.intersection(total_space)
+ sorted(
+ (
+ rev for rev in branch_todo
+ if not rev._all_nextrev.intersection(total_space)
+ ),
+ # favor "revisioned" branch points before
+ # dependent ones
+ key=lambda rev: 0 if rev.is_branch_point else 1
+ )
)
branch_todo.difference_update(todo)
-
# iterate nodes that are in the immediate todo
while todo:
rev = todo.popleft()
@@ -555,7 +593,7 @@ class RevisionMap(object):
# don't consume any actual branch nodes
todo.extendleft([
self._revision_map[downrev]
- for downrev in reversed(rev._down_revision_tuple)
+ for downrev in reversed(rev._all_down_revisions)
if self._revision_map[downrev] not in branch_todo
and downrev in total_space])
@@ -577,28 +615,56 @@ class Revision(object):
"""
nextrev = frozenset()
+ """following revisions, based on down_revision only."""
+
+ _all_nextrev = frozenset()
revision = None
"""The string revision number."""
down_revision = None
- """The ``down_revision`` identifier(s) within the migration script."""
+ """The ``down_revision`` identifier(s) within the migration script.
+
+ Note that the total set of "down" revisions is
+ down_revision + dependencies.
+
+ """
+
+ dependencies = None
+ """Additional revisions which this revision is dependent on.
+
+ From a migration standpoint, these dependencies are added to the
+ down_revision to form the full iteration. However, the separation
+ of down_revision from "dependencies" is to assist in navigating
+ a history that contains many branches, typically a multi-root scenario.
+
+ """
branch_labels = None
"""Optional string/tuple of symbolic names to apply to this
revision's branch"""
- def __init__(self, revision, down_revision, branch_labels=None):
+ def __init__(
+ self, revision, down_revision,
+ dependencies=None, branch_labels=None):
self.revision = revision
self.down_revision = tuple_rev_as_scalar(down_revision)
+ self.dependencies = tuple_rev_as_scalar(dependencies)
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
- def add_nextrev(self, rev):
- self.nextrev = self.nextrev.union([rev])
+ def add_nextrev(self, revision):
+ self._all_nextrev = self._all_nextrev.union([revision.revision])
+ if self.revision in revision._versioned_down_revisions:
+ self.nextrev = self.nextrev.union([revision.revision])
+
+ @property
+ def _all_down_revisions(self):
+ return util.to_tuple(self.down_revision, default=()) + \
+ util.to_tuple(self.dependencies, default=())
@property
- def _down_revision_tuple(self):
+ def _versioned_down_revisions(self):
return util.to_tuple(self.down_revision, default=())
@property
@@ -613,12 +679,23 @@ class Revision(object):
return not bool(self.nextrev)
@property
+ def _is_real_head(self):
+ return not bool(self._all_nextrev)
+
+ @property
def is_base(self):
"""Return True if this :class:`.Revision` is a 'base' revision."""
return self.down_revision is None
@property
+ def _is_real_base(self):
+ """Return True if this :class:`.Revision` is a "real" base revision,
+ e.g. that it has no dependencies either."""
+
+ return self.down_revision is None and self.dependencies is None
+
+ @property
def is_branch_point(self):
"""Return True if this :class:`.Script` is a branch point.
@@ -631,10 +708,18 @@ class Revision(object):
return len(self.nextrev) > 1
@property
+ def _is_real_branch_point(self):
+ """Return True if this :class:`.Script` is a 'real' branch point,
+ taking into account dependencies as well.
+
+ """
+ return len(self._all_nextrev) > 1
+
+ @property
def is_merge_point(self):
"""Return True if this :class:`.Script` is a merge point."""
- return len(self._down_revision_tuple) > 1
+ return len(self._versioned_down_revisions) > 1
def tuple_rev_as_scalar(rev):
diff --git a/alembic/script.py b/alembic/script.py
index 1835605..2147652 100644
--- a/alembic/script.py
+++ b/alembic/script.py
@@ -483,7 +483,10 @@ class Script(revision.Revision):
rev_id,
module.down_revision,
branch_labels=util.to_tuple(
- getattr(module, 'branch_labels', None), default=()))
+ getattr(module, 'branch_labels', None), default=()),
+ dependencies=util.to_tuple(
+ getattr(module, 'depends_on', None), default=())
+ )
module = None
"""The Python module representing the actual script itself."""
@@ -522,6 +525,10 @@ class Script(revision.Revision):
else:
entry += "Parent: %s\n" % (self._format_down_revision(), )
+ if self.dependencies:
+ entry += "Depends on: %s\n" % (
+ util.format_as_comma(self.dependencies))
+
if self.is_branch_point:
entry += "Branches into: %s\n" % (
util.format_as_comma(self.nextrev))
@@ -554,8 +561,15 @@ class Script(revision.Revision):
include_parents=False, tree_indicators=True):
text = self.revision
if include_parents:
- text = "%s -> %s" % (
- self._format_down_revision(), text)
+ if self.dependencies:
+ text = "%s (%s) -> %s" % (
+ self._format_down_revision(),
+ util.format_as_comma(self.dependencies),
+ text
+ )
+ else:
+ text = "%s -> %s" % (
+ self._format_down_revision(), text)
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if tree_indicators:
@@ -584,7 +598,7 @@ class Script(revision.Revision):
if not self.down_revision:
return "<base>"
else:
- return util.format_as_comma(self._down_revision_tuple)
+ return util.format_as_comma(self._versioned_down_revisions)
@classmethod
def _from_path(cls, scriptdir, path):
diff --git a/tests/test_revision.py b/tests/test_revision.py
index 10fae91..e9e8935 100644
--- a/tests/test_revision.py
+++ b/tests/test_revision.py
@@ -715,3 +715,160 @@ class MultipleBaseTest(DownIterateTest):
['b3', 'a3', 'b2', 'a2', 'base2'],
inclusive=False, implicit_base=True
)
+
+
+class MultipleBaseCrossDependencyTestOne(DownIterateTest):
+ def setUp(self):
+ self.map = RevisionMap(
+ lambda: [
+ Revision('base1', (), branch_labels='b_1'),
+ Revision('a1a', ('base1',)),
+ Revision('a1b', ('base1',)),
+ Revision('b1a', ('a1a',)),
+ Revision('b1b', ('a1b', ), dependencies='a3'),
+
+ Revision('base2', (), branch_labels='b_2'),
+ Revision('a2', ('base2',)),
+ Revision('b2', ('a2',)),
+ Revision('c2', ('b2', ), dependencies='a3'),
+ Revision('d2', ('c2',)),
+
+ Revision('base3', (), branch_labels='b_3'),
+ Revision('a3', ('base3',)),
+ Revision('b3', ('a3',)),
+ ]
+ )
+
+ def test_what_are_the_heads(self):
+ eq_(self.map.heads, ("b1a", "b1b", "d2", "b3"))
+
+ def test_heads_to_base(self):
+ self._assert_iteration(
+ "heads", "base",
+ [
+
+ 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
+ 'b3', 'a3', 'base3',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_head2(self):
+ # the 2 branch relies on the 3 branch
+ self._assert_iteration(
+ "b_2@head", "base",
+ ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3']
+ )
+
+ def test_we_need_head3(self):
+ # the 3 branch can be upgraded alone.
+ self._assert_iteration(
+ "b_3@head", "base",
+ ['b3', 'a3', 'base3']
+ )
+
+ def test_we_need_head1(self):
+ # the 1 branch relies on the 3 branch
+ self._assert_iteration(
+ "b1b@head", "base",
+ ['b1b', 'a1b', 'base1', 'a3', 'base3']
+ )
+
+ def test_we_need_base2(self):
+ # consider a downgrade to b_2@base - we
+ # want to run through all the "2"s alone, and we're done.
+ self._assert_iteration(
+ "heads", "b_2@base",
+ ['d2', 'c2', 'b2', 'a2', 'base2']
+ )
+
+ def test_we_need_base3(self):
+ # consider a downgrade to b_3@base - due to the a3 dependency, we
+ # need to downgrade everything dependent on a3
+ # as well, which means b1b and c2. Then we can downgrade
+ # the 3s.
+ self._assert_iteration(
+ "heads", "b_3@base",
+ ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3']
+ )
+
+
+class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
+ def setUp(self):
+ self.map = RevisionMap(
+ lambda: [
+ Revision('base1', (), branch_labels='b_1'),
+ Revision('a1', 'base1'),
+ Revision('b1', 'a1'),
+ Revision('c1', 'b1'),
+
+ Revision('base2', (), dependencies='base1', branch_labels='b_2'),
+ Revision('a2', 'base2'),
+ Revision('b2', 'a2'),
+ Revision('c2', 'b2'),
+ Revision('d2', 'c2'),
+
+ Revision('base3', (), branch_labels='b_3'),
+ Revision('a3', 'base3'),
+ Revision('b3', 'a3'),
+ Revision('c3', 'b3', dependencies='b2'),
+ Revision('d3', 'c3'),
+ ]
+ )
+
+ def test_what_are_the_heads(self):
+ eq_(self.map.heads, ("c1", "d2", "d3"))
+
+ def test_heads_to_base(self):
+ self._assert_iteration(
+ "heads", "base",
+ [
+ 'c1', 'b1', 'a1',
+ 'd2', 'c2',
+ 'd3', 'c3', 'b3', 'a3', 'base3',
+ 'b2', 'a2', 'base2',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_head2(self):
+ self._assert_iteration(
+ "b_2@head", "base",
+ ['d2', 'c2', 'b2', 'a2', 'base2', 'base1']
+ )
+
+ def test_we_need_head3(self):
+ self._assert_iteration(
+ "b_3@head", "base",
+ ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1']
+ )
+
+ def test_we_need_head1(self):
+ self._assert_iteration(
+ "b_1@head", "base",
+ ['c1', 'b1', 'a1', 'base1']
+ )
+
+ def test_we_need_base1(self):
+ self._assert_iteration(
+ "heads", "b_1@base",
+ [
+ 'c1', 'b1', 'a1',
+ 'd2', 'c2',
+ 'd3', 'c3', 'b2', 'a2', 'base2',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_base2(self):
+ self._assert_iteration(
+ "heads", "b_2@base",
+ ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2']
+ )
+
+ def test_we_need_base3(self):
+ self._assert_iteration(
+ "heads", "b_3@base",
+ ['d3', 'c3', 'b3', 'a3', 'base3']
+ )
+