summaryrefslogtreecommitdiff
path: root/alembic/testing/fixtures.py
blob: ae25fd276827839eb0393eeed6bedd205860f07d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# coding: utf-8
import io
import re

from sqlalchemy import create_engine, text, MetaData

import alembic
from alembic.compat import configparser
from alembic import util
from alembic.compat import string_types, text_type
from alembic.migration import MigrationContext
from alembic.environment import EnvironmentContext
from alembic.operations import Operations
from alembic.ddl.impl import _impls
from contextlib import contextmanager
from .plugin.plugin_base import SkipTest
from .assertions import _get_dialect, eq_
from . import mock

# Module-level parser for test-suite settings.  It is populated at import
# time from ``test.cfg``, a relative path and therefore resolved against the
# current working directory.
testing_config = configparser.ConfigParser()
testing_config.read(['test.cfg'])


if util.sqla_094:
    # The installed SQLAlchemy is flagged as 0.9.4-compatible: use the base
    # class from its own testing package.
    from sqlalchemy.testing.fixtures import TestBase
else:
    class TestBase(object):
        """Local stand-in base class used when ``util.sqla_094`` is false.

        The dunder-style class attributes are declarative markers with
        neutral defaults; nothing in this class reads them itself.
        """

        # Database names that always run, whatever the constraints
        # below say.
        __whitelist__ = ()

        # Requirement names matching testing.requires decorators.
        __requires__ = ()

        # Dialect names on which this test class must not run.
        __unsupported_on__ = ()

        # When set, the test class runs on that *single* dialect only.
        # For several dialects, use __unsupported_on__ and invert.
        __only_on__ = None

        # A sequence of no-arg callables; if any is True the whole
        # testcase is skipped.
        __skip_if__ = None

        def assert_(self, val, msg=None):
            """Assert that ``val`` is truthy, failing with ``msg``."""
            assert val, msg

        # A handful of tests rely on the unittest-style hook names, so the
        # lower-case hooks forward to them when they exist.
        def setup(self):
            """Call ``self.setUp()`` if the instance defines it."""
            if not hasattr(self, "setUp"):
                return
            self.setUp()

        def teardown(self):
            """Call ``self.tearDown()`` if the instance defines it."""
            if not hasattr(self, "tearDown"):
                return
            self.tearDown()


def capture_db():
    """Create a mock PostgreSQL engine together with its capture list.

    The engine is built with ``strategy="mock"`` and an executor callback
    that compiles whatever it is handed against the engine's own dialect
    and appends the resulting string to a list.

    :return: a 2-tuple ``(engine, buf)`` where ``buf`` is that list.
    """
    statements = []

    def _record(sql, *multiparams, **params):
        # ``mock_engine`` is bound below; the closure resolves it at call
        # time, after ``create_engine`` has returned.
        compiled = sql.compile(dialect=mock_engine.dialect)
        statements.append(str(compiled))

    mock_engine = create_engine(
        "postgresql://", strategy="mock", executor=_record)
    return mock_engine, statements

# NOTE(review): not referenced anywhere else in the visible part of this
# module; presumably a cache of engines -- confirm usage before removing.
_engs = {}


@contextmanager
def capture_context_buffer(**kw):
    """Temporarily patch ``EnvironmentContext.configure`` to use a buffer.

    While the ``with`` block is active, every call to
    ``EnvironmentContext.configure`` has its keyword options overlaid with
    ``kw`` plus two forced entries: ``dialect_name="sqlite"`` and
    ``output_buffer`` set to the in-memory buffer that is yielded.

    :param bytes_io: optional flag, consumed here and not forwarded; when
     truthy the buffer is an ``io.BytesIO``, otherwise an ``io.StringIO``.
    :param kw: any other keyword arguments are forced onto each
     ``configure`` call.
    """
    wants_bytes = kw.pop('bytes_io', False)
    buf = io.BytesIO() if wants_bytes else io.StringIO()

    # These two always win over anything the caller passed under the
    # same keys.
    kw['dialect_name'] = "sqlite"
    kw['output_buffer'] = buf

    original_configure = EnvironmentContext.configure

    def _patched_configure(*arg, **opt):
        # Overlay the forced options, then delegate to the real method.
        opt.update(**kw)
        return original_configure(*arg, **opt)

    with mock.patch.object(
            EnvironmentContext, "configure", _patched_configure):
        yield buf


def op_fixture(
        dialect='default', as_sql=False,
        naming_convention=None, literal_binds=False):
    """Install a recording migration context as ``alembic.op._proxy``.

    A ``MigrationContext`` subclass is created whose emitted SQL is
    collected, one normalized string per statement, in an in-memory
    buffer.  The context is wrapped in ``Operations`` and assigned to
    ``alembic.op._proxy``.

    :param dialect: passed to ``_get_dialect`` to obtain the dialect used
     for the context and for compiling statements.
    :param as_sql: when truthy, the buffer is handed to the context as
     ``output_buffer`` and no connection is used; otherwise a mock
     connection is supplied whose ``execute`` compiles each statement and
     writes it to the buffer.
    :param naming_convention: when truthy, a ``MetaData`` built with this
     convention is passed as ``target_metadata``.
    :param literal_binds: when truthy, forwarded to the context options.
    :return: the context, which additionally offers ``assert_``,
     ``assert_contains`` and ``clear_assertions``.
    :raises SkipTest: if ``naming_convention`` is given while
     ``util.sqla_092`` is false.
    """

    opts = {}
    if naming_convention:
        if not util.sqla_092:
            raise SkipTest(
                "naming_convention feature requires "
                "sqla 0.9.2 or greater")
        opts['target_metadata'] = MetaData(naming_convention=naming_convention)

    class buffer_(object):
        """Minimal file-like sink storing each normalized write."""

        def __init__(self):
            self.lines = []

        def write(self, msg):
            msg = msg.strip()
            msg = re.sub(r'[\n\t]', '', msg)
            if as_sql:
                # the impl produces soft tabs,
                # so search for blocks of 4 spaces
                msg = re.sub(r'    ', '', msg)
                # Raw string on purpose: in a plain literal '\;' is an
                # invalid escape sequence and triggers a warning on
                # current Python versions.  The pattern is unchanged.
                msg = re.sub(r'\;\n*$', '', msg)

            self.lines.append(msg)

        def flush(self):
            # Nothing is buffered beyond ``lines``; present only to
            # satisfy the file-like interface.
            pass

    buf = buffer_()

    class ctx(MigrationContext):
        """Migration context with assertion helpers over ``buf.lines``."""

        def clear_assertions(self):
            # Clear in place so the same list object stays shared.
            buf.lines[:] = []

        def assert_(self, *sql):
            # TODO: make this more flexible about
            # whitespace and such
            eq_(buf.lines, list(sql))

        def assert_contains(self, sql):
            for stmt in buf.lines:
                if sql in stmt:
                    return
            assert False, "Could not locate fragment %r in %r" % (
                sql,
                buf.lines
            )

    if as_sql:
        opts['as_sql'] = as_sql
    if literal_binds:
        opts['literal_binds'] = literal_binds
    ctx_dialect = _get_dialect(dialect)
    if not as_sql:
        def execute(stmt, *multiparam, **param):
            # Plain strings are wrapped so they can be compiled like any
            # other SQL construct.
            if isinstance(stmt, string_types):
                stmt = text(stmt)
            assert stmt.supports_execution
            sql = text_type(stmt.compile(dialect=ctx_dialect))

            buf.write(sql)

        connection = mock.Mock(dialect=ctx_dialect, execute=execute)
    else:
        opts['output_buffer'] = buf
        connection = None
    context = ctx(
        ctx_dialect,
        connection,
        opts)

    alembic.op._proxy = Operations(context)
    return context