summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-11-22 11:49:20 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2014-11-22 11:49:20 -0500
commit296c5160285500fd52973d327608410c6b4ebbdd (patch)
tree9ca8f3ddb95b8e13b948666315c46a60c13d7cdf
parent43bb4ec14aacc8e014d565f88876a32bc8a90428 (diff)
downloadalembic-296c5160285500fd52973d327608410c6b4ebbdd.tar.gz
- break out the concept of "down revision" into two pieces:
down_revision and "dependencies". For migration traversal, the downrevs we care about are the union of these two sets. however for location of nodes and branch labeling, we look only at down_revsion. this works really well and allows us to have mutually-dependent trees that can easily be itererated independently of each other. docs are needed
-rw-r--r--alembic/migration.py22
-rw-r--r--alembic/revision.py153
-rw-r--r--alembic/script.py22
-rw-r--r--tests/test_revision.py157
4 files changed, 305 insertions, 49 deletions
diff --git a/alembic/migration.py b/alembic/migration.py
index 7d854ba..eefb1d1 100644
--- a/alembic/migration.py
+++ b/alembic/migration.py
@@ -553,7 +553,7 @@ class RevisionStep(MigrationStep):
@property
def from_revisions(self):
if self.is_upgrade:
- return self.revision._down_revision_tuple
+ return self.revision._all_down_revisions
else:
return (self.revision.revision, )
@@ -562,11 +562,11 @@ class RevisionStep(MigrationStep):
if self.is_upgrade:
return (self.revision.revision, )
else:
- return self.revision._down_revision_tuple
+ return self.revision._all_down_revisions
@property
def _has_scalar_down_revision(self):
- return len(self.revision._down_revision_tuple) == 1
+ return len(self.revision._all_down_revisions) == 1
def should_delete_branch(self, heads):
if not self.is_downgrade:
@@ -575,7 +575,7 @@ class RevisionStep(MigrationStep):
if self.revision.revision not in heads:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if not downrevs:
# is a base
return True
@@ -587,7 +587,7 @@ class RevisionStep(MigrationStep):
descendants = set(
r.revision for r in self.revision_map._get_descendant_nodes(
- self.revision_map.get_revisions(downrev.nextrev),
+ self.revision_map.get_revisions(downrev._all_nextrev),
check=False
)
)
@@ -606,7 +606,7 @@ class RevisionStep(MigrationStep):
# TODO: this doesn't work; make sure tests are here to ensure
# this fails
- #if len(downrev.nextrev.intersection(heads).difference(
+ #if len(downrev._all_nextrev.intersection(heads).difference(
# [self.revision.revision])):
return True
@@ -662,7 +662,7 @@ class RevisionStep(MigrationStep):
if not self.is_upgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if not downrevs:
# is a base
@@ -680,7 +680,7 @@ class RevisionStep(MigrationStep):
if not self.is_upgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if len(downrevs) > 1 and \
len(heads.intersection(downrevs)) > 1:
@@ -692,7 +692,7 @@ class RevisionStep(MigrationStep):
if not self.is_downgrade:
return False
- downrevs = self.revision._down_revision_tuple
+ downrevs = self.revision._all_down_revisions
if self.revision.revision in heads and len(downrevs) > 1:
return True
@@ -701,12 +701,12 @@ class RevisionStep(MigrationStep):
def update_version_num(self, heads):
if not self._has_scalar_down_revision:
- downrev = heads.intersection(self.revision._down_revision_tuple)
+ downrev = heads.intersection(self.revision._all_down_revisions)
assert len(downrev) == 1, \
"Can't do an UPDATE because downrevision is ambiguous"
down_revision = list(downrev)[0]
else:
- down_revision = self.revision.down_revision
+ down_revision = self.revision._all_down_revisions[0]
if self.is_upgrade:
return down_revision, self.revision.revision
diff --git a/alembic/revision.py b/alembic/revision.py
index 071452f..0ab52ab 100644
--- a/alembic/revision.py
+++ b/alembic/revision.py
@@ -82,6 +82,16 @@ class RevisionMap(object):
return self.bases
@util.memoized_property
+ def _real_bases(self):
+ """All "real" base revisions as strings.
+
+ :return: a tuple of string revision numbers.
+
+ """
+ self._revision_map
+ return self._real_bases
+
+ @util.memoized_property
def _revision_map(self):
"""memoized attribute, initializes the revision map from the
initial collection.
@@ -91,6 +101,7 @@ class RevisionMap(object):
heads = sqlautil.OrderedSet()
self.bases = ()
+ self._real_bases = ()
has_branch_labels = set()
for revision in self._generator():
@@ -104,14 +115,16 @@ class RevisionMap(object):
heads.add(revision.revision)
if revision.is_base:
self.bases += (revision.revision, )
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
for rev in map_.values():
- for downrev in rev._down_revision_tuple:
+ for downrev in rev._all_down_revisions:
if downrev not in map_:
util.warn("Revision %s referenced from %s is not present"
- % (rev.down_revision, rev))
+ % (downrev, rev))
down_revision = map_[downrev]
- down_revision.add_nextrev(rev.revision)
+ down_revision.add_nextrev(rev)
heads.discard(downrev)
map_[None] = map_[()] = None
@@ -133,12 +146,13 @@ class RevisionMap(object):
)
map_[branch_label] = revision
revision.branch_labels.update(revision.branch_labels)
- for node in self._get_descendant_nodes([revision], map_):
+ for node in self._get_descendant_nodes(
+ [revision], map_, include_dependencies=False):
node.branch_labels.update(revision.branch_labels)
parent = node
while parent and \
- not parent.is_branch_point and not parent.is_merge_point:
+ not parent._is_real_branch_point and not parent.is_merge_point:
parent.branch_labels.update(revision.branch_labels)
if parent.down_revision:
@@ -164,18 +178,20 @@ class RevisionMap(object):
self._add_branches(revision, map_)
if revision.is_base:
self.bases += (revision.revision, )
- for downrev in revision._down_revision_tuple:
+ if revision._is_real_base:
+ self._real_bases += (revision.revision, )
+ for downrev in revision._all_down_revisions:
if downrev not in map_:
util.warn(
"Revision %s referenced from %s is not present"
- % (revision.down_revision, revision)
+ % (downrev, revision)
)
- map_[downrev].add_nextrev(revision.revision)
- if revision.is_head:
+ map_[downrev].add_nextrev(revision)
+ if revision._is_real_head:
self.heads = tuple(
head for head in self.heads
if head not in
- set(revision._down_revision_tuple).union([revision.revision])
+ set(revision._all_down_revisions).union([revision.revision])
) + (revision.revision,)
def get_current_head(self, branch_label=None):
@@ -334,8 +350,12 @@ class RevisionMap(object):
]
return bool(
- set(self._get_descendant_nodes([target]))
- .union(self._get_ancestor_nodes([target]))
+ set(self._get_descendant_nodes([target],
+ include_dependencies=False
+ ))
+ .union(self._get_ancestor_nodes([target],
+ include_dependencies=False
+ ))
.intersection(test_against_revs)
)
@@ -424,16 +444,28 @@ class RevisionMap(object):
return self._iterate_revisions(
upper, lower, inclusive=inclusive, implicit_base=implicit_base)
- def _get_descendant_nodes(self, targets, map_=None, check=False):
+ def _get_descendant_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_nextrev
+ else:
+ fn = lambda rev: rev.nextrev
+
return self._iterate_related_revisions(
- lambda rev: rev.nextrev,
- targets, map_=map_, check=check
+ fn, targets, map_=map_, check=check
)
- def _get_ancestor_nodes(self, targets, map_=None, check=False):
+ def _get_ancestor_nodes(
+ self, targets, map_=None, check=False, include_dependencies=True):
+
+ if include_dependencies:
+ fn = lambda rev: rev._all_down_revisions
+ else:
+ fn = lambda rev: rev._versioned_down_revisions
+
return self._iterate_related_revisions(
- lambda rev: rev._down_revision_tuple,
- targets, map_=map_, check=check
+ fn, targets, map_=map_, check=check
)
def _iterate_related_revisions(self, fn, targets, map_, check=False):
@@ -494,17 +526,17 @@ class RevisionMap(object):
difference(lower_ancestors).\
difference(lower_descendants)
for rev in candidate_lowers:
- for downrev in rev._down_revision_tuple:
+ for downrev in rev._all_down_revisions:
if self._revision_map[downrev] in candidate_lowers:
break
else:
base_lowers.add(rev)
lowers = base_lowers.union(requested_lowers)
elif implicit_base:
- base_lowers = set(self.get_revisions(self.bases))
+ base_lowers = set(self.get_revisions(self._real_bases))
lowers = base_lowers.union(requested_lowers)
elif not requested_lowers:
- lowers = set(self.get_revisions(self.bases))
+ lowers = set(self.get_revisions(self._real_bases))
else:
lowers = requested_lowers
@@ -523,12 +555,12 @@ class RevisionMap(object):
branch_todo = set(
rev for rev in
(self._revision_map[rev] for rev in total_space)
- if rev.is_branch_point and
- len(total_space.intersection(rev.nextrev)) > 1
+ if rev._is_real_branch_point and
+ len(total_space.intersection(rev._all_nextrev)) > 1
)
# it's not possible for any "uppers" to be in branch_todo,
- # because the .nextrev of those nodes is not in total_space
+ # because the ._all_nextrev of those nodes is not in total_space
#assert not branch_todo.intersection(uppers)
todo = collections.deque(
@@ -541,11 +573,17 @@ class RevisionMap(object):
# descendants left in the queue
if not todo:
todo.extendleft(
- rev for rev in branch_todo
- if not rev.nextrev.intersection(total_space)
+ sorted(
+ (
+ rev for rev in branch_todo
+ if not rev._all_nextrev.intersection(total_space)
+ ),
+ # favor "revisioned" branch points before
+ # dependent ones
+ key=lambda rev: 0 if rev.is_branch_point else 1
+ )
)
branch_todo.difference_update(todo)
-
# iterate nodes that are in the immediate todo
while todo:
rev = todo.popleft()
@@ -555,7 +593,7 @@ class RevisionMap(object):
# don't consume any actual branch nodes
todo.extendleft([
self._revision_map[downrev]
- for downrev in reversed(rev._down_revision_tuple)
+ for downrev in reversed(rev._all_down_revisions)
if self._revision_map[downrev] not in branch_todo
and downrev in total_space])
@@ -577,28 +615,56 @@ class Revision(object):
"""
nextrev = frozenset()
+ """following revisions, based on down_revision only."""
+
+ _all_nextrev = frozenset()
revision = None
"""The string revision number."""
down_revision = None
- """The ``down_revision`` identifier(s) within the migration script."""
+ """The ``down_revision`` identifier(s) within the migration script.
+
+ Note that the total set of "down" revisions is
+ down_revision + dependencies.
+
+ """
+
+ dependencies = None
+ """Additional revisions which this revision is dependent on.
+
+ From a migration standpoint, these dependencies are added to the
+ down_revision to form the full iteration. However, the separation
+ of down_revision from "dependencies" is to assist in navigating
+ a history that contains many branches, typically a multi-root scenario.
+
+ """
branch_labels = None
"""Optional string/tuple of symbolic names to apply to this
revision's branch"""
- def __init__(self, revision, down_revision, branch_labels=None):
+ def __init__(
+ self, revision, down_revision,
+ dependencies=None, branch_labels=None):
self.revision = revision
self.down_revision = tuple_rev_as_scalar(down_revision)
+ self.dependencies = tuple_rev_as_scalar(dependencies)
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
- def add_nextrev(self, rev):
- self.nextrev = self.nextrev.union([rev])
+ def add_nextrev(self, revision):
+ self._all_nextrev = self._all_nextrev.union([revision.revision])
+ if self.revision in revision._versioned_down_revisions:
+ self.nextrev = self.nextrev.union([revision.revision])
+
+ @property
+ def _all_down_revisions(self):
+ return util.to_tuple(self.down_revision, default=()) + \
+ util.to_tuple(self.dependencies, default=())
@property
- def _down_revision_tuple(self):
+ def _versioned_down_revisions(self):
return util.to_tuple(self.down_revision, default=())
@property
@@ -613,12 +679,23 @@ class Revision(object):
return not bool(self.nextrev)
@property
+ def _is_real_head(self):
+ return not bool(self._all_nextrev)
+
+ @property
def is_base(self):
"""Return True if this :class:`.Revision` is a 'base' revision."""
return self.down_revision is None
@property
+ def _is_real_base(self):
+ """Return True if this :class:`.Revision` is a "real" base revision,
+ e.g. that it has no dependencies either."""
+
+ return self.down_revision is None and self.dependencies is None
+
+ @property
def is_branch_point(self):
"""Return True if this :class:`.Script` is a branch point.
@@ -631,10 +708,18 @@ class Revision(object):
return len(self.nextrev) > 1
@property
+ def _is_real_branch_point(self):
+ """Return True if this :class:`.Script` is a 'real' branch point,
+ taking into account dependencies as well.
+
+ """
+ return len(self._all_nextrev) > 1
+
+ @property
def is_merge_point(self):
"""Return True if this :class:`.Script` is a merge point."""
- return len(self._down_revision_tuple) > 1
+ return len(self._versioned_down_revisions) > 1
def tuple_rev_as_scalar(rev):
diff --git a/alembic/script.py b/alembic/script.py
index 1835605..2147652 100644
--- a/alembic/script.py
+++ b/alembic/script.py
@@ -483,7 +483,10 @@ class Script(revision.Revision):
rev_id,
module.down_revision,
branch_labels=util.to_tuple(
- getattr(module, 'branch_labels', None), default=()))
+ getattr(module, 'branch_labels', None), default=()),
+ dependencies=util.to_tuple(
+ getattr(module, 'depends_on', None), default=())
+ )
module = None
"""The Python module representing the actual script itself."""
@@ -522,6 +525,10 @@ class Script(revision.Revision):
else:
entry += "Parent: %s\n" % (self._format_down_revision(), )
+ if self.dependencies:
+ entry += "Depends on: %s\n" % (
+ util.format_as_comma(self.dependencies))
+
if self.is_branch_point:
entry += "Branches into: %s\n" % (
util.format_as_comma(self.nextrev))
@@ -554,8 +561,15 @@ class Script(revision.Revision):
include_parents=False, tree_indicators=True):
text = self.revision
if include_parents:
- text = "%s -> %s" % (
- self._format_down_revision(), text)
+ if self.dependencies:
+ text = "%s (%s) -> %s" % (
+ self._format_down_revision(),
+ util.format_as_comma(self.dependencies),
+ text
+ )
+ else:
+ text = "%s -> %s" % (
+ self._format_down_revision(), text)
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if tree_indicators:
@@ -584,7 +598,7 @@ class Script(revision.Revision):
if not self.down_revision:
return "<base>"
else:
- return util.format_as_comma(self._down_revision_tuple)
+ return util.format_as_comma(self._versioned_down_revisions)
@classmethod
def _from_path(cls, scriptdir, path):
diff --git a/tests/test_revision.py b/tests/test_revision.py
index 10fae91..e9e8935 100644
--- a/tests/test_revision.py
+++ b/tests/test_revision.py
@@ -715,3 +715,160 @@ class MultipleBaseTest(DownIterateTest):
['b3', 'a3', 'b2', 'a2', 'base2'],
inclusive=False, implicit_base=True
)
+
+
+class MultipleBaseCrossDependencyTestOne(DownIterateTest):
+ def setUp(self):
+ self.map = RevisionMap(
+ lambda: [
+ Revision('base1', (), branch_labels='b_1'),
+ Revision('a1a', ('base1',)),
+ Revision('a1b', ('base1',)),
+ Revision('b1a', ('a1a',)),
+ Revision('b1b', ('a1b', ), dependencies='a3'),
+
+ Revision('base2', (), branch_labels='b_2'),
+ Revision('a2', ('base2',)),
+ Revision('b2', ('a2',)),
+ Revision('c2', ('b2', ), dependencies='a3'),
+ Revision('d2', ('c2',)),
+
+ Revision('base3', (), branch_labels='b_3'),
+ Revision('a3', ('base3',)),
+ Revision('b3', ('a3',)),
+ ]
+ )
+
+ def test_what_are_the_heads(self):
+ eq_(self.map.heads, ("b1a", "b1b", "d2", "b3"))
+
+ def test_heads_to_base(self):
+ self._assert_iteration(
+ "heads", "base",
+ [
+
+ 'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
+ 'b3', 'a3', 'base3',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_head2(self):
+ # the 2 branch relies on the 3 branch
+ self._assert_iteration(
+ "b_2@head", "base",
+ ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3']
+ )
+
+ def test_we_need_head3(self):
+ # the 3 branch can be upgraded alone.
+ self._assert_iteration(
+ "b_3@head", "base",
+ ['b3', 'a3', 'base3']
+ )
+
+ def test_we_need_head1(self):
+ # the 1 branch relies on the 3 branch
+ self._assert_iteration(
+ "b1b@head", "base",
+ ['b1b', 'a1b', 'base1', 'a3', 'base3']
+ )
+
+ def test_we_need_base2(self):
+ # consider a downgrade to b_2@base - we
+ # want to run through all the "2"s alone, and we're done.
+ self._assert_iteration(
+ "heads", "b_2@base",
+ ['d2', 'c2', 'b2', 'a2', 'base2']
+ )
+
+ def test_we_need_base3(self):
+ # consider a downgrade to b_3@base - due to the a3 dependency, we
+ # need to downgrade everything dependent on a3
+ # as well, which means b1b and c2. Then we can downgrade
+ # the 3s.
+ self._assert_iteration(
+ "heads", "b_3@base",
+ ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3']
+ )
+
+
+class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
+ def setUp(self):
+ self.map = RevisionMap(
+ lambda: [
+ Revision('base1', (), branch_labels='b_1'),
+ Revision('a1', 'base1'),
+ Revision('b1', 'a1'),
+ Revision('c1', 'b1'),
+
+ Revision('base2', (), dependencies='base1', branch_labels='b_2'),
+ Revision('a2', 'base2'),
+ Revision('b2', 'a2'),
+ Revision('c2', 'b2'),
+ Revision('d2', 'c2'),
+
+ Revision('base3', (), branch_labels='b_3'),
+ Revision('a3', 'base3'),
+ Revision('b3', 'a3'),
+ Revision('c3', 'b3', dependencies='b2'),
+ Revision('d3', 'c3'),
+ ]
+ )
+
+ def test_what_are_the_heads(self):
+ eq_(self.map.heads, ("c1", "d2", "d3"))
+
+ def test_heads_to_base(self):
+ self._assert_iteration(
+ "heads", "base",
+ [
+ 'c1', 'b1', 'a1',
+ 'd2', 'c2',
+ 'd3', 'c3', 'b3', 'a3', 'base3',
+ 'b2', 'a2', 'base2',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_head2(self):
+ self._assert_iteration(
+ "b_2@head", "base",
+ ['d2', 'c2', 'b2', 'a2', 'base2', 'base1']
+ )
+
+ def test_we_need_head3(self):
+ self._assert_iteration(
+ "b_3@head", "base",
+ ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1']
+ )
+
+ def test_we_need_head1(self):
+ self._assert_iteration(
+ "b_1@head", "base",
+ ['c1', 'b1', 'a1', 'base1']
+ )
+
+ def test_we_need_base1(self):
+ self._assert_iteration(
+ "heads", "b_1@base",
+ [
+ 'c1', 'b1', 'a1',
+ 'd2', 'c2',
+ 'd3', 'c3', 'b2', 'a2', 'base2',
+ 'base1'
+ ]
+ )
+
+ def test_we_need_base2(self):
+ self._assert_iteration(
+ "heads", "b_2@base",
+ ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2']
+ )
+
+ def test_we_need_base3(self):
+ self._assert_iteration(
+ "heads", "b_3@base",
+ ['d3', 'c3', 'b3', 'a3', 'base3']
+ )
+