summaryrefslogtreecommitdiff
path: root/alembic/revision.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/revision.py')
-rw-r--r--alembic/revision.py140
1 files changed, 93 insertions, 47 deletions
diff --git a/alembic/revision.py b/alembic/revision.py
index 1afb203..7a09e1a 100644
--- a/alembic/revision.py
+++ b/alembic/revision.py
@@ -1,11 +1,12 @@
import re
import collections
+import itertools
from . import util
from sqlalchemy import util as sqlautil
from . import compat
-_relative_destination = re.compile(r'(?:(.+?)@)?((?:\+|-)\d+)')
+_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)')
class RevisionError(Exception):
@@ -402,6 +403,83 @@ class RevisionMap(object):
else:
return util.to_tuple(id_, default=None), branch_label
+ def _relative_iterate(
+ self, destination, source, is_upwards,
+ implicit_base, inclusive, assert_relative_length):
+ if isinstance(destination, compat.string_types):
+ match = _relative_destination.match(destination)
+ if not match:
+ return None
+ else:
+ return None
+
+ relative = int(match.group(3))
+ symbol = match.group(2)
+ branch_label = match.group(1)
+
+ reldelta = 1 if inclusive and not symbol else 0
+
+ if is_upwards:
+ if branch_label:
+ from_ = "%s@head" % branch_label
+ elif symbol:
+ if symbol.startswith("head"):
+ from_ = symbol
+ else:
+ from_ = "%s@head" % symbol
+ else:
+ from_ = "head"
+ to_ = source
+ else:
+ if branch_label:
+ to_ = "%s@base" % branch_label
+ elif symbol:
+ to_ = "%s@base" % symbol
+ else:
+ to_ = "base"
+ from_ = source
+
+ revs = list(
+ self._iterate_revisions(
+ from_, to_,
+ inclusive=inclusive, implicit_base=implicit_base))
+
+ if symbol:
+ if branch_label:
+ symbol_rev = self.get_revision(
+ "%s@%s" % (branch_label, symbol))
+ else:
+ symbol_rev = self.get_revision(symbol)
+ if symbol.startswith("head"):
+ index = 0
+ elif symbol == "base":
+ index = len(revs) - 1
+ else:
+ range_ = compat.range(len(revs) - 1, 0, -1)
+ for index in range_:
+ if symbol_rev.revision == revs[index].revision:
+ break
+ else:
+ index = 0
+ else:
+ index = 0
+ if is_upwards:
+ revs = revs[index - relative - reldelta:]
+ if not index and assert_relative_length and \
+ len(revs) < abs(relative - reldelta):
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (destination, abs(relative)))
+ else:
+ revs = revs[0:index - relative + reldelta]
+ if not index and assert_relative_length and \
+ len(revs) != abs(relative) + reldelta:
+ raise RevisionError(
+ "Relative revision %s didn't "
+ "produce %d migrations" % (destination, abs(relative)))
+
+ return iter(revs)
+
def iterate_revisions(
self, upper, lower, implicit_base=False, inclusive=False,
assert_relative_length=True):
@@ -417,54 +495,22 @@ class RevisionMap(object):
"""
- if isinstance(upper, compat.string_types) and \
- _relative_destination.match(upper):
-
- reldelta = 1 if inclusive else 0
- match = _relative_destination.match(upper)
- relative = int(match.group(2))
- branch_label = match.group(1)
- if branch_label:
- from_ = "%s@head" % branch_label
- else:
- from_ = "head"
- revs = list(
- self._iterate_revisions(
- from_, lower,
- inclusive=inclusive, implicit_base=implicit_base))
- revs = revs[-relative - reldelta:]
- if assert_relative_length and \
- len(revs) != abs(relative) + reldelta:
- raise RevisionError(
- "Relative revision %s didn't "
- "produce %d migrations" % (upper, abs(relative)))
- return iter(revs)
- elif isinstance(lower, compat.string_types) and \
- _relative_destination.match(lower):
- reldelta = 1 if inclusive else 0
- match = _relative_destination.match(lower)
- relative = int(match.group(2))
- branch_label = match.group(1)
+ relative_upper = self._relative_iterate(
+ upper, lower, True, implicit_base,
+ inclusive, assert_relative_length
+ )
+ if relative_upper:
+ return relative_upper
- if branch_label:
- to_ = "%s@base" % branch_label
- else:
- to_ = "base"
+ relative_lower = self._relative_iterate(
+ lower, upper, False, implicit_base,
+ inclusive, assert_relative_length
+ )
+ if relative_lower:
+ return relative_lower
- revs = list(
- self._iterate_revisions(
- upper, to_,
- inclusive=inclusive, implicit_base=implicit_base))
- revs = revs[0:-relative + reldelta]
- if assert_relative_length and \
- len(revs) != abs(relative) + reldelta:
- raise RevisionError(
- "Relative revision %s didn't "
- "produce %d migrations" % (lower, abs(relative)))
- return iter(revs)
- else:
- return self._iterate_revisions(
- upper, lower, inclusive=inclusive, implicit_base=implicit_base)
+ return self._iterate_revisions(
+ upper, lower, inclusive=inclusive, implicit_base=implicit_base)
def _get_descendant_nodes(
self, targets, map_=None, check=False, include_dependencies=True):