diff options
Diffstat (limited to 'alembic/revision.py')
-rw-r--r-- | alembic/revision.py | 140 |
1 files changed, 93 insertions, 47 deletions
diff --git a/alembic/revision.py b/alembic/revision.py index 1afb203..7a09e1a 100644 --- a/alembic/revision.py +++ b/alembic/revision.py @@ -1,11 +1,12 @@ import re import collections +import itertools from . import util from sqlalchemy import util as sqlautil from . import compat -_relative_destination = re.compile(r'(?:(.+?)@)?((?:\+|-)\d+)') +_relative_destination = re.compile(r'(?:(.+?)@)?(\w+)?((?:\+|-)\d+)') class RevisionError(Exception): @@ -402,6 +403,83 @@ class RevisionMap(object): else: return util.to_tuple(id_, default=None), branch_label + def _relative_iterate( + self, destination, source, is_upwards, + implicit_base, inclusive, assert_relative_length): + if isinstance(destination, compat.string_types): + match = _relative_destination.match(destination) + if not match: + return None + else: + return None + + relative = int(match.group(3)) + symbol = match.group(2) + branch_label = match.group(1) + + reldelta = 1 if inclusive and not symbol else 0 + + if is_upwards: + if branch_label: + from_ = "%s@head" % branch_label + elif symbol: + if symbol.startswith("head"): + from_ = symbol + else: + from_ = "%s@head" % symbol + else: + from_ = "head" + to_ = source + else: + if branch_label: + to_ = "%s@base" % branch_label + elif symbol: + to_ = "%s@base" % symbol + else: + to_ = "base" + from_ = source + + revs = list( + self._iterate_revisions( + from_, to_, + inclusive=inclusive, implicit_base=implicit_base)) + + if symbol: + if branch_label: + symbol_rev = self.get_revision( + "%s@%s" % (branch_label, symbol)) + else: + symbol_rev = self.get_revision(symbol) + if symbol.startswith("head"): + index = 0 + elif symbol == "base": + index = len(revs) - 1 + else: + range_ = compat.range(len(revs) - 1, 0, -1) + for index in range_: + if symbol_rev.revision == revs[index].revision: + break + else: + index = 0 + else: + index = 0 + if is_upwards: + revs = revs[index - relative - reldelta:] + if not index and assert_relative_length and \ + len(revs) < abs(relative - reldelta): + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (destination, abs(relative))) + else: + revs = revs[0:index - relative + reldelta] + if not index and assert_relative_length and \ + len(revs) != abs(relative) + reldelta: + raise RevisionError( + "Relative revision %s didn't " + "produce %d migrations" % (destination, abs(relative))) + + return iter(revs) + def iterate_revisions( self, upper, lower, implicit_base=False, inclusive=False, assert_relative_length=True): @@ -417,54 +495,22 @@ class RevisionMap(object): """ - if isinstance(upper, compat.string_types) and \ - _relative_destination.match(upper): - - reldelta = 1 if inclusive else 0 - match = _relative_destination.match(upper) - relative = int(match.group(2)) - branch_label = match.group(1) - if branch_label: - from_ = "%s@head" % branch_label - else: - from_ = "head" - revs = list( - self._iterate_revisions( - from_, lower, - inclusive=inclusive, implicit_base=implicit_base)) - revs = revs[-relative - reldelta:] - if assert_relative_length and \ - len(revs) != abs(relative) + reldelta: - raise RevisionError( - "Relative revision %s didn't " - "produce %d migrations" % (upper, abs(relative))) - return iter(revs) - elif isinstance(lower, compat.string_types) and \ - _relative_destination.match(lower): - reldelta = 1 if inclusive else 0 - match = _relative_destination.match(lower) - relative = int(match.group(2)) - branch_label = match.group(1) + relative_upper = self._relative_iterate( + upper, lower, True, implicit_base, + inclusive, assert_relative_length + ) + if relative_upper: + return relative_upper - if branch_label: - to_ = "%s@base" % branch_label - else: - to_ = "base" + relative_lower = self._relative_iterate( + lower, upper, False, implicit_base, + inclusive, assert_relative_length + ) + if relative_lower: + return relative_lower - revs = list( - self._iterate_revisions( - upper, to_, - inclusive=inclusive, implicit_base=implicit_base)) - revs = revs[0:-relative + reldelta] - if assert_relative_length and \ - len(revs) != abs(relative) + reldelta: - raise RevisionError( - "Relative revision %s didn't " - "produce %d migrations" % (lower, abs(relative))) - return iter(revs) - else: - return self._iterate_revisions( - upper, lower, inclusive=inclusive, implicit_base=implicit_base) + return self._iterate_revisions( + upper, lower, inclusive=inclusive, implicit_base=implicit_base) def _get_descendant_nodes( self, targets, map_=None, check=False, include_dependencies=True): |