summaryrefslogtreecommitdiff
path: root/alembic/migration.py
blob: 7d509af7bd73c710511c670e215a347a9b2e902f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
import logging
import sys
from contextlib import contextmanager

from sqlalchemy import MetaData, Table, Column, String, literal_column
from sqlalchemy import create_engine
from sqlalchemy.engine import url as sqla_url

from .compat import callable, EncodedIO
from . import ddl, util

# Module-level logger; MigrationContext and HeadMaintainer below emit their
# info/debug messages through it.
log = logging.getLogger(__name__)


class MigrationContext(object):

    """Represent the database state made available to a migration
    script.

    :class:`.MigrationContext` is the front end to an actual
    database connection, or alternatively a string output
    stream given a particular database dialect,
    from an Alembic perspective.

    When inside the ``env.py`` script, the :class:`.MigrationContext`
    is available via the
    :meth:`.EnvironmentContext.get_context` method,
    which is available at ``alembic.context``::

        # from within env.py script
        from alembic import context
        migration_context = context.get_context()

    For usage outside of an ``env.py`` script, such as for
    utility routines that want to check the current version
    in the database, use the :meth:`.MigrationContext.configure`
    method to create new :class:`.MigrationContext` objects.
    For example, to get at the current revision in the
    database using :meth:`.MigrationContext.get_current_revision`::

        # in any application, outside of an env.py script
        from alembic.migration import MigrationContext
        from sqlalchemy import create_engine

        engine = create_engine("postgresql://mydatabase")
        conn = engine.connect()

        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    The above context can also be used to produce
    Alembic migration operations with an :class:`.Operations`
    instance::

        # in any application, outside of the normal Alembic environment
        from alembic.operations import Operations
        op = Operations(context)
        op.alter_column("mytable", "somecolumn", nullable=True)

    """

    def __init__(self, dialect, connection, opts, environment_context=None):
        self.environment_context = environment_context
        self.opts = opts
        self.dialect = dialect
        self.script = opts.get('script')

        as_sql = opts.get('as_sql', False)
        transactional_ddl = opts.get("transactional_ddl")

        self._transaction_per_migration = opts.get(
            "transaction_per_migration", False)

        if as_sql:
            # offline mode: statements go to a mock engine that writes SQL
            # to the output buffer instead of a live database
            self.connection = self._stdout_connection(connection)
            assert self.connection is not None
        else:
            self.connection = connection
        self._migrations_fn = opts.get('fn')
        self.as_sql = as_sql

        if "output_encoding" in opts:
            self.output_buffer = EncodedIO(
                opts.get("output_buffer") or sys.stdout,
                opts['output_encoding']
            )
        else:
            self.output_buffer = opts.get("output_buffer", sys.stdout)

        self._user_compare_type = opts.get('compare_type', False)
        self._user_compare_server_default = opts.get(
            'compare_server_default',
            False)
        self.version_table = version_table = opts.get(
            'version_table', 'alembic_version')
        self.version_table_schema = version_table_schema = \
            opts.get('version_table_schema', None)
        self._version = Table(
            version_table, MetaData(),
            Column('version_num', String(32), nullable=False),
            schema=version_table_schema)

        self._start_from_rev = opts.get("starting_rev")
        self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
            dialect, self.connection, self.as_sql,
            transactional_ddl,
            self.output_buffer,
            opts
        )
        log.info("Context impl %s.", self.impl.__class__.__name__)
        if self.as_sql:
            log.info("Generating static SQL")
        log.info("Will assume %s DDL.",
                 "transactional" if self.impl.transactional_ddl
                 else "non-transactional")

    @classmethod
    def configure(cls,
                  connection=None,
                  url=None,
                  dialect_name=None,
                  environment_context=None,
                  opts=None,
                  ):
        """Create a new :class:`.MigrationContext`.

        This is a factory method usually called
        by :meth:`.EnvironmentContext.configure`.

        The new object is created through ``cls``, so calling this on a
        subclass produces an instance of that subclass.

        :param connection: a :class:`~sqlalchemy.engine.Connection`
         to use for SQL execution in "online" mode.  When present,
         is also used to determine the type of dialect in use.
        :param url: a string database url, or a
         :class:`sqlalchemy.engine.url.URL` object.
         The type of dialect to be used will be derived from this if
         ``connection`` is not passed.
        :param dialect_name: string name of a dialect, such as
         "postgresql", "mssql", etc.  The type of dialect to be used will be
         derived from this if ``connection`` and ``url`` are not passed.
        :param opts: dictionary of options.  Most other options
         accepted by :meth:`.EnvironmentContext.configure` are passed via
         this dictionary.
        :raises Exception: if none of ``connection``, ``url`` or
         ``dialect_name`` is given.

        """
        if opts is None:
            opts = {}

        if connection:
            dialect = connection.dialect
        elif url:
            url = sqla_url.make_url(url)
            dialect = url.get_dialect()()
        elif dialect_name:
            url = sqla_url.make_url("%s://" % dialect_name)
            dialect = url.get_dialect()()
        else:
            raise Exception("Connection, url, or dialect_name is required.")

        return cls(dialect, connection, opts, environment_context)

    def begin_transaction(self, _per_migration=False):
        """Return a context manager delimiting a transaction.

        :meth:`.run_migrations` calls this with ``_per_migration=True``
        around each individual step; the default ``False`` denotes the
        enclosing, whole-run scope.  A transaction is only produced at the
        scope matching the ``transaction_per_migration`` option, and only
        when the impl reports transactional DDL.  At any other scope, or
        for non-transactional DDL, the returned context manager does
        nothing.

        In offline (``as_sql``) mode the transaction is rendered as
        BEGIN / COMMIT output via the impl; online it is
        ``self.bind.begin()``.
        """
        transaction_now = _per_migration == self._transaction_per_migration

        if not transaction_now or not self.impl.transactional_ddl:
            # either the transaction belongs to the other scope, or the
            # backend cannot run DDL inside a transaction
            @contextmanager
            def do_nothing():
                yield
            return do_nothing()
        elif self.as_sql:
            @contextmanager
            def begin_commit():
                self.impl.emit_begin()
                yield
                self.impl.emit_commit()
            return begin_commit()
        else:
            return self.bind.begin()

    def get_current_revision(self):
        """Return the current revision, usually that which is present
        in the ``alembic_version`` table in the database.

        This method intends to be used only for a migration stream that
        does not contain unmerged branches in the target database;
        if there are multiple branches present, an exception is raised.
        The :meth:`.MigrationContext.get_current_heads` should be preferred
        over this method going forward in order to be compatible with
        branch migration support.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned instead, if any.

        """
        heads = self.get_current_heads()
        if len(heads) == 0:
            return None
        elif len(heads) > 1:
            raise util.CommandError(
                "Version table '%s' has more than one head present; "
                "please use get_current_heads()" % self.version_table)
        else:
            return heads[0]

    def get_current_heads(self):
        """Return a tuple of the current 'head versions' that are represented
        in the target database.

        For a migration stream without branches, this will be a single
        value, synonymous with that of
        :meth:`.MigrationContext.get_current_revision`.   However when multiple
        unmerged branches exist within the target database, the returned tuple
        will contain a value for each head.

        If this :class:`.MigrationContext` was configured in "offline"
        mode, that is with ``as_sql=True``, the ``starting_rev``
        parameter is returned in a one-length tuple.

        If no version table is present, or if there are no revisions
        present, an empty tuple is returned.

        .. versionadded:: 0.7.0

        """
        if self.as_sql:
            return util.to_tuple(self._start_from_rev, default=())

        # online mode: the database, not starting_rev, is authoritative
        if self._start_from_rev:
            raise util.CommandError(
                "Can't specify current_rev to context "
                "when using a database connection")
        if not self._has_version_table():
            return ()
        return tuple(
            row[0] for row in self.connection.execute(self._version.select())
        )

    def _ensure_version_table(self):
        """Create the version table on the connection unless it exists."""
        self._version.create(self.connection, checkfirst=True)

    def _has_version_table(self):
        """Ask the dialect whether the version table exists."""
        return self.connection.dialect.has_table(
            self.connection, self.version_table, self.version_table_schema)

    def stamp(self, script_directory, revision):
        """Stamp the version table with a specific revision.

        This method calculates those branches to which the given revision
        can apply, and updates those branches as though they were migrated
        towards that revision (either up or down).  If no current branches
        include the revision, it is added as a new branch head.

        .. versionadded:: 0.7.0

        """
        heads = self.get_current_heads()
        head_maintainer = HeadMaintainer(self, heads)
        for step in script_directory._steps_revs(revision, heads):
            head_maintainer.update_to_step(step)

    def run_migrations(self, **kw):
        r"""Run the migration scripts established for this
        :class:`.MigrationContext`, if any.

        The commands in :mod:`alembic.command` will set up a function
        that is ultimately passed to the :class:`.MigrationContext`
        as the ``fn`` argument.  This function represents the "work"
        that will be done when :meth:`.MigrationContext.run_migrations`
        is called, typically from within the ``env.py`` script of the
        migration environment.  The "work function" then provides an iterable
        of version callables and other version information which
        in the case of the ``upgrade`` or ``downgrade`` commands are the
        list of version scripts to invoke.  Other commands yield nothing,
        in the case that a command wants to run some other operation
        against the database such as the ``current`` or ``stamp`` commands.

        :param \**kw: keyword arguments here will be passed to each
         migration callable, that is the ``upgrade()`` or ``downgrade()``
         method within revision scripts.

        """
        self.impl.start_migrations()

        heads = self.get_current_heads()
        if not self.as_sql and not heads:
            self._ensure_version_table()

        head_maintainer = HeadMaintainer(self, heads)

        for step in self._migrations_fn(heads, self):
            with self.begin_transaction(_per_migration=True):
                if self.as_sql and not head_maintainer.heads:
                    # for offline mode, include a CREATE TABLE from
                    # the base
                    self._version.create(self.connection)
                log.info("Running %s", step)
                if self.as_sql:
                    self.impl.static_output("-- Running %s" % (step.short_log,))
                step.migration_fn(**kw)

                # previously, we wouldn't stamp per migration
                # if we were in a transaction, however given the more
                # complex model that involves any number of inserts
                # and row-targeted updates and deletes, it's simpler for now
                # just to run the operations on every version
                head_maintainer.update_to_step(step)

        if self.as_sql and not head_maintainer.heads:
            self._version.drop(self.connection)

    def execute(self, sql, execution_options=None):
        """Execute a SQL construct or string statement.

        The underlying execution mechanics are used, that is
        if this is "offline mode" the SQL is written to the
        output buffer, otherwise the SQL is emitted on
        the current SQLAlchemy connection.

        """
        self.impl._exec(sql, execution_options)

    def _stdout_connection(self, connection):
        """Build a mock engine whose executor routes constructs to
        ``self.impl._exec``.

        The ``connection`` argument is not used.
        """
        def dump(construct, *multiparams, **params):
            self.impl._exec(construct)

        return create_engine("%s://" % self.dialect.name,
                             strategy="mock", executor=dump)

    @property
    def bind(self):
        """Return the current "bind".

        In online mode, this is an instance of
        :class:`sqlalchemy.engine.Connection`, and is suitable
        for ad-hoc execution of any kind of usage described
        in :ref:`sqlexpression_toplevel` as well as
        for usage with the :meth:`sqlalchemy.schema.Table.create`
        and :meth:`sqlalchemy.schema.MetaData.create_all` methods
        of :class:`~sqlalchemy.schema.Table`,
        :class:`~sqlalchemy.schema.MetaData`.

        Note that when "standard output" mode is enabled,
        this bind will be a "mock" connection handler that cannot
        return results and is only appropriate for a very limited
        subset of commands.

        """
        return self.connection

    @property
    def config(self):
        """Return the :class:`.Config` used by the current environment, if any.

        .. versionadded:: 0.6.6

        """
        if self.environment_context:
            return self.environment_context.config
        else:
            return None

    def _compare_type(self, inspector_column, metadata_column):
        """Decide whether two columns' types should be reported as different.

        Returns ``False`` when the ``compare_type`` option is ``False``.  If
        the option is callable, a non-``None`` result from it is returned
        as-is; otherwise the decision is delegated to
        ``self.impl.compare_type``.
        """
        if self._user_compare_type is False:
            return False

        if callable(self._user_compare_type):
            user_value = self._user_compare_type(
                self,
                inspector_column,
                metadata_column,
                inspector_column.type,
                metadata_column.type
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_type(
            inspector_column,
            metadata_column)

    def _compare_server_default(self, inspector_column,
                                metadata_column,
                                rendered_metadata_default,
                                rendered_column_default):
        """Decide whether two columns' server defaults differ.

        Returns ``False`` when the ``compare_server_default`` option is
        ``False``.  A callable option receives ``(context, inspector_column,
        metadata_column, rendered_column_default,
        metadata_column.server_default, rendered_metadata_default)`` and a
        non-``None`` result is returned as-is; otherwise the decision is
        delegated to ``self.impl.compare_server_default``.
        """

        if self._user_compare_server_default is False:
            return False

        if callable(self._user_compare_server_default):
            user_value = self._user_compare_server_default(
                self,
                inspector_column,
                metadata_column,
                rendered_column_default,
                metadata_column.server_default,
                rendered_metadata_default
            )
            if user_value is not None:
                return user_value

        return self.impl.compare_server_default(
            inspector_column,
            metadata_column,
            rendered_metadata_default,
            rendered_column_default)


class HeadMaintainer(object):
    """Keep an in-memory set of version heads in step with the version table.

    Every mutation first adjusts ``self.heads`` and then emits the matching
    INSERT / DELETE / UPDATE against ``context._version`` through
    ``context.impl._exec``.  Outside offline mode (``context.as_sql`` false)
    DELETE and UPDATE must touch exactly one row, or
    ``util.CommandError`` is raised.
    """

    def __init__(self, context, heads):
        self.context = context
        # private copy as a set; the caller's sequence is never mutated
        self.heads = set(heads)

    @staticmethod
    def _quoted(version):
        # NOTE(review): the revision id is spliced into a SQL literal
        # without escaping; this presumes ids never contain quotes.
        return literal_column("'%s'" % version)

    def _insert_version(self, version):
        """Add ``version`` as a new head and INSERT its row."""
        assert version not in self.heads
        self.heads.add(version)

        table = self.context._version
        self.context.impl._exec(
            table.insert().values(version_num=self._quoted(version))
        )

    def _delete_version(self, version):
        """Remove head ``version`` and DELETE its row."""
        self.heads.remove(version)

        table = self.context._version
        result = self.context.impl._exec(
            table.delete().where(
                table.c.version_num == self._quoted(version)))
        if self.context.as_sql or result.rowcount == 1:
            return
        raise util.CommandError(
            "Online migration expected to match one "
            "row when deleting '%s' in '%s'; "
            "%d found"
            % (version,
               self.context.version_table, result.rowcount))

    def _update_version(self, from_, to_):
        """Replace head ``from_`` by ``to_`` and UPDATE the matching row."""
        assert to_ not in self.heads
        self.heads.remove(from_)
        self.heads.add(to_)

        table = self.context._version
        result = self.context.impl._exec(
            table.update()
            .values(version_num=self._quoted(to_))
            .where(table.c.version_num == self._quoted(from_))
        )
        if self.context.as_sql or result.rowcount == 1:
            return
        raise util.CommandError(
            "Online migration expected to match one "
            "row when updating '%s' to '%s' in '%s'; "
            "%d found"
            % (from_, to_, self.context.version_table, result.rowcount))

    def update_to_step(self, step):
        """Apply the version-table change implied by ``step``.

        The step's predicates are consulted in a fixed order -- delete a
        branch, create a branch, merge, unmerge -- and the first match
        decides the action; when none match, a plain one-to-one update
        is made.
        """
        heads = self.heads

        if step.should_delete_branch(heads):
            doomed = step.delete_version_num
            log.debug("branch delete %s", doomed)
            self._delete_version(doomed)
            return

        if step.should_create_branch(heads):
            fresh = step.insert_version_num
            log.debug("new branch insert %s", fresh)
            self._insert_version(fresh)
            return

        if step.should_merge_branches(heads):
            # delete revs, update from rev, update to rev
            to_delete, old_rev, new_rev = step.merge_branch_idents
            log.debug(
                "merge, delete %s, update %s to %s",
                to_delete, old_rev, new_rev)
            for rev in to_delete:
                self._delete_version(rev)
            self._update_version(old_rev, new_rev)
            return

        if step.should_unmerge_branches(heads):
            # update from rev, update to rev, insert revs
            old_rev, new_rev, to_insert = step.unmerge_branch_idents
            log.debug(
                "unmerge, insert %s, update %s to %s",
                to_insert, old_rev, new_rev)
            for rev in to_insert:
                self._insert_version(rev)
            self._update_version(old_rev, new_rev)
            return

        old_rev, new_rev = step.update_version_num
        log.debug("update %s to %s", old_rev, new_rev)
        self._update_version(old_rev, new_rev)


class MigrationStep(object):
    """Base class for one step of a migration run.

    Subclasses provide ``migration_fn`` (the callable to run), ``is_upgrade``,
    ``doc``, ``from_revisions`` and ``to_revisions``; the helpers here are
    derived from those.
    """

    @property
    def name(self):
        return self.migration_fn.__name__

    @classmethod
    def upgrade_from_script(cls, revision_map, script):
        return RevisionStep(revision_map, script, True)

    @classmethod
    def downgrade_from_script(cls, revision_map, script):
        return RevisionStep(revision_map, script, False)

    @property
    def is_downgrade(self):
        return not self.is_upgrade

    @property
    def merge_branch_idents(self):
        return (
            # delete revs, update from rev, update to rev
            list(self.from_revisions[0:-1]), self.from_revisions[-1],
            self.to_revisions[0]
        )

    @property
    def unmerge_branch_idents(self):
        return (
            # update from rev, update to rev, insert revs
            self.from_revisions[0], self.to_revisions[-1],
            list(self.to_revisions[0:-1])
        )

    @property
    def short_log(self):
        return "%s %s -> %s" % (
            self.name,
            util.format_as_comma(self.from_revisions),
            util.format_as_comma(self.to_revisions)
        )

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # subclasses defining ``__eq__`` would compare unequal by identity.
        # NotImplemented is passed through so Python can try the reflected
        # operation / identity fallback as usual.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        if self.doc:
            return "%s %s -> %s, %s" % (
                self.name,
                util.format_as_comma(self.from_revisions),
                util.format_as_comma(self.to_revisions),
                self.doc
            )
        else:
            return self.short_log


class RevisionStep(MigrationStep):
    """A step that runs a revision script's ``upgrade`` or ``downgrade``.

    :param revision_map: map consulted (``get_revision``, ``get_revisions``,
     ``_get_descendant_nodes``) when deciding whether a downgrade removes a
     branch.
    :param revision: the revision object; ``revision.module`` supplies the
     ``upgrade`` / ``downgrade`` callables.
    :param is_upgrade: ``True`` to run ``upgrade``, ``False`` for
     ``downgrade``.

    The ``should_*`` predicates take the current heads; ``heads`` must be a
    set, since :meth:`should_merge_branches` calls ``heads.intersection``.
    """

    def __init__(self, revision_map, revision, is_upgrade):
        self.revision_map = revision_map
        self.revision = revision
        self.is_upgrade = is_upgrade
        fn_name = "upgrade" if is_upgrade else "downgrade"
        self.migration_fn = getattr(revision.module, fn_name)

    def __eq__(self, other):
        if not isinstance(other, RevisionStep):
            return False
        return (other.revision == self.revision and
                self.is_upgrade == other.is_upgrade)

    @property
    def doc(self):
        return self.revision.doc

    @property
    def from_revisions(self):
        """Revisions the step starts from: the down revisions when
        upgrading, otherwise this revision alone."""
        rev = self.revision
        return rev._down_revision_tuple if self.is_upgrade \
            else (rev.revision, )

    @property
    def to_revisions(self):
        """Revisions the step ends at: this revision alone when upgrading,
        otherwise its down revisions."""
        rev = self.revision
        return (rev.revision, ) if self.is_upgrade \
            else rev._down_revision_tuple

    @property
    def _has_scalar_down_revision(self):
        return len(self.revision._down_revision_tuple) == 1

    def should_delete_branch(self, heads):
        """True when this downgrade removes its head from ``heads``
        entirely instead of moving it down."""
        if not self.is_downgrade or self.revision.revision not in heads:
            return False

        parents = self.revision._down_revision_tuple
        if not parents:
            # is a base
            return True
        if len(parents) > 1:
            # is a merge point
            return False

        parent = self.revision_map.get_revision(parents[0])
        if not parent.is_branch_point:
            return False

        # The parent is a branch point: if other members or descendants of
        # that branch point are still in heads, this branch is deleted.
        # Traversal stays fully on one branch down to the branch point
        # before starting the other; with a->b->(c1->d1->e1, c2->d2->e2),
        # downgrading from the top may go e1, d1, c1, leaving heads at c1
        # and e2.  Only looking at the parent's direct ``nextrev`` misses
        # "e2", so all descendants of c1/c2 are gathered.
        # TODO: keep tests guarding against the nextrev-only approach.
        branch_members = {
            node.revision
            for node in self.revision_map._get_descendant_nodes(
                self.revision_map.get_revisions(parent.nextrev),
                check=False
            )
        }
        others_in_heads = branch_members.intersection(heads).difference(
            [self.revision.revision])
        return bool(others_in_heads)

    def should_create_branch(self, heads):
        """True when this upgrade starts a new head rather than moving an
        existing one."""
        if not self.is_upgrade:
            return False

        parents = self.revision._down_revision_tuple
        if not parents:
            # is a base
            return True
        if len(parents) == 1:
            return parents[0] not in heads
        # is a merge point
        return False

    def should_merge_branches(self, heads):
        """True when this upgrade is a merge point joining more than one
        current head."""
        if not self.is_upgrade:
            return False

        parents = self.revision._down_revision_tuple
        return len(parents) > 1 and len(heads.intersection(parents)) > 1

    def should_unmerge_branches(self, heads):
        """True when this downgrade leaves a merge point that is currently
        a head."""
        if not self.is_downgrade:
            return False

        parents = self.revision._down_revision_tuple
        return self.revision.revision in heads and len(parents) > 1

    @property
    def update_version_num(self):
        """``(from, to)`` pair for a plain single-parent version update."""
        assert self._has_scalar_down_revision
        rev = self.revision
        if self.is_upgrade:
            return rev.down_revision, rev.revision
        return rev.revision, rev.down_revision

    @property
    def delete_version_num(self):
        return self.revision.revision

    @property
    def insert_version_num(self):
        return self.revision.revision


class StampStep(MigrationStep):
    """A step that moves version numbers without running any script code.

    ``migration_fn`` is :meth:`stamp_revision`, which does nothing, so the
    step only carries ``from`` / ``to`` revisions for whoever applies it.

    :param from_: revision(s) moved away from; normalized to a tuple
     (``()`` when empty) via ``util.to_tuple``.
    :param to_: revision(s) moved to; normalized the same way.
    :param is_upgrade: direction of the move.
    :param branch_move: when true, a downgrade deletes a branch and an
     upgrade creates one (see ``should_delete_branch`` /
     ``should_create_branch``).
    """

    def __init__(self, from_, to_, is_upgrade, branch_move):
        self.from_ = util.to_tuple(from_, default=())
        self.to_ = util.to_tuple(to_, default=())
        self.is_upgrade = is_upgrade
        self.branch_move = branch_move
        self.migration_fn = self.stamp_revision

    doc = None

    def stamp_revision(self, **kw):
        """No-op migration callable; a stamp runs no script code."""
        return None

    def __eq__(self, other):
        # previously compared against a nonexistent ``self.revisions``
        # attribute, raising AttributeError for any StampStep comparison
        return isinstance(other, StampStep) and \
            other.from_revisions == self.from_revisions and \
            other.to_revisions == self.to_revisions and \
            other.branch_move == self.branch_move and \
            self.is_upgrade == other.is_upgrade

    @property
    def from_revisions(self):
        return self.from_

    @property
    def to_revisions(self):
        return self.to_

    @property
    def delete_version_num(self):
        assert len(self.from_) == 1
        return self.from_[0]

    @property
    def insert_version_num(self):
        assert len(self.to_) == 1
        return self.to_[0]

    @property
    def update_version_num(self):
        assert len(self.from_) == 1
        assert len(self.to_) == 1
        return self.from_[0], self.to_[0]

    def should_delete_branch(self, heads):
        return self.is_downgrade and self.branch_move

    def should_create_branch(self, heads):
        return self.is_upgrade and self.branch_move

    def should_merge_branches(self, heads):
        return len(self.from_) > 1

    def should_unmerge_branches(self, heads):
        return len(self.to_) > 1