summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing/assertsql.py
blob: 3e0d4c9d3c2583a86e2e38790c877a9c202ed0ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
# testing/assertsql.py
# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from ..engine.default import DefaultDialect
from .. import util
import re


class AssertRule(object):
    """Base class for a rule checked against statements as they execute.

    Subclasses are notified of executions through :meth:`process_execute`
    and :meth:`process_cursor_execute`, and report their state through
    :meth:`is_consumed`, :meth:`rule_passed` and :meth:`consume_final`.
    """

    def process_execute(self, clauseelement, *multiparams, **params):
        """Statement-level execution hook; does nothing by default."""
        return None

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Cursor-level execution hook; does nothing by default."""

    def is_consumed(self):
        """Report whether this rule is finished.

        Subclasses return True if consumed and False if not, and should
        raise AssertionError once the rule's condition has definitely
        failed.
        """
        raise NotImplementedError()

    def rule_passed(self):
        """Report the outcome of the last test of this rule.

        Subclasses return True (passed), False (failed) or None (no test
        applied yet).
        """
        raise NotImplementedError()

    def consume_final(self):
        """Final check once all statements have run.

        Fails with AssertionError when ``self._result`` is None (the rule
        was never exercised); otherwise defers to :meth:`is_consumed`.
        ``self._result`` is expected to be set up by the subclass.
        """
        # evaluated unconditionally so a missing _result still raises
        # AttributeError even when asserts are stripped (-O)
        never_tested = self._result is None
        assert not never_tested, 'Rule has not been consumed'
        return self.is_consumed()


class SQLMatchRule(AssertRule):
    """Base for rules that compare one executed statement to an expectation.

    Subclasses record the comparison outcome in ``_result`` (None until a
    comparison happens) and a failure description in ``_errmsg``.
    """

    def __init__(self):
        self._errmsg = ""
        self._result = None

    def rule_passed(self):
        """Return the last comparison outcome (None if none was made)."""
        return self._result

    def is_consumed(self):
        """Return False before any comparison, True after a successful one.

        Raises AssertionError carrying ``_errmsg`` if the comparison failed.
        """
        outcome = self._result
        if outcome is None:
            return False
        assert outcome, self._errmsg
        return True


class ExactSQL(SQLMatchRule):
    """Match the executed statement text exactly.

    :param sql: expected SQL text written with ``:name`` bind markers; it
     is converted to the executing dialect's paramstyle before comparison.
    :param params: optional expected parameters: a dict, a list of dicts,
     or a callable taking the execution context and returning either.
     When given, the resulting list must equal
     ``context.compiled_parameters``.
    """

    def __init__(self, sql, params=None):
        SQLMatchRule.__init__(self)
        self.sql = sql
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        # nothing to compare without an execution context
        if not context:
            return
        received_sql = _process_engine_statement(
            context.unicode_statement, context)
        received_params = context.compiled_parameters

        # TODO: remove this step once all unit tests are migrated, as
        # ExactSQL should really be *exact* SQL
        expected_sql = _process_assertion_statement(self.sql, context)
        matched = received_sql == expected_sql

        expected_params = {}
        if self.params:
            expected_params = (self.params(context)
                               if util.callable(self.params)
                               else self.params)
            if not isinstance(expected_params, list):
                expected_params = [expected_params]
            matched = matched and \
                expected_params == context.compiled_parameters

        self._result = matched
        if not matched:
            self._errmsg = (
                'Testing for exact statement %r exact params %r, '
                'received %r with params %r'
                % (expected_sql, expected_params, received_sql,
                   received_params))


class RegexSQL(SQLMatchRule):
    """Match the executed statement text against a regular expression.

    :param regex: pattern applied with ``match()`` (anchored at the start
     of the statement only) to the received statement text.
    :param params: optional expected parameters: a dict, a list of dicts,
     or a callable taking the execution context and returning either.
     Only a positive, partial comparison is made: expected dicts are
     paired positionally with received parameter sets (pairs beyond the
     shorter list are not checked) and every key an expected dict names
     must be present in its partner with an equal value.  Extra received
     keys are ignored.
    """

    def __init__(self, regex, params=None):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        # nothing to compare without an execution context
        if not context:
            return
        received_sql = _process_engine_statement(
            context.unicode_statement, context)
        received_params = context.compiled_parameters
        matched = self.regex.match(received_sql) is not None

        expected_params = {}
        if self.params:
            expected_params = (self.params(context)
                               if util.callable(self.params)
                               else self.params)
            if not isinstance(expected_params, list):
                expected_params = [expected_params]

            # do a positive compare only
            mismatch = any(
                key not in got or got[key] != value
                for want, got in zip(expected_params, received_params)
                for key, value in want.items())
            if mismatch:
                matched = False

        self._result = matched
        if not matched:
            self._errmsg = (
                'Testing for regex %r partial params %r, received %r '
                'with params %r'
                % (self.orig_regex, expected_params, received_sql,
                   received_params))


class CompiledSQL(SQLMatchRule):
    """Match a statement by recompiling it with the default dialect.

    The executed construct (``context.compiled.statement``) is recompiled
    against ``DefaultDialect``, newlines and tabs are stripped from the
    result, and that string must equal ``statement`` exactly.

    :param statement: expected SQL string as rendered by the default
     dialect, with newlines and tabs removed.
    :param params: optional expected parameters: a dict, a list of dicts,
     or a callable taking the execution context and returning either.
     Each expected dict is completed with any keys from
     ``context.compiled.params`` it does not already specify.  The
     expected and received parameter sets must then match one-to-one,
     in any order.
    """

    def __init__(self, statement, params=None):
        SQLMatchRule.__init__(self)
        self.statement = statement
        self.params = params

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        # nothing to compare without an execution context
        if not context:
            return
        from sqlalchemy.schema import _DDLCompiles
        _received_parameters = list(context.compiled_parameters)
        # untouched copy for the failure message; the matching loop below
        # removes entries from _received_parameters as they are matched
        all_received = list(_received_parameters)

        # recompile from the context, using the default dialect
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = \
                context.compiled.statement.compile(dialect=DefaultDialect())
        else:
            compiled = \
                context.compiled.statement.compile(dialect=DefaultDialect(),
                column_keys=context.compiled.column_keys)
        _received_statement = re.sub(r'[\n\t]', '', str(compiled))
        equivalent = self.statement == _received_statement

        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            else:
                params = list(params)
            all_params = list(params)
            while params:
                param = dict(params.pop(0))
                for k, v in context.compiled.params.items():
                    param.setdefault(k, v)
                if param not in _received_parameters:
                    equivalent = False
                    break
                _received_parameters.remove(param)
            # any received parameter set left over had no expected match
            if _received_parameters:
                equivalent = False
        else:
            all_params = {}

        self._result = equivalent
        if not equivalent:
            self._errmsg = (
                'Testing for compiled statement %r partial params %r, '
                'received %r with params %r'
                % (self.statement, all_params, _received_statement,
                   all_received))
            # also echo to stdout: a rule nested inside AllOf is only asked
            # rule_passed(), so its _errmsg would otherwise never be shown
            print(self._errmsg)


class CountStatements(AssertRule):
    """Assert how many times :meth:`process_execute` is called.

    The count is only verified by :meth:`consume_final`; the rule never
    reports itself consumed mid-stream.
    """

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_execute(self, clauseelement, *multiparams, **params):
        """Tally one statement-level execution."""
        self._statement_count = self._statement_count + 1

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        """Cursor-level executions are not counted."""

    def is_consumed(self):
        """Never consumed before the final check."""
        return False

    def consume_final(self):
        """Assert the tallied count equals the expected one; return True."""
        expected, actual = self.count, self._statement_count
        assert expected == actual, \
            'desired statement count %d does not match %d' % (expected,
                                                               actual)
        return True


class AllOf(AssertRule):
    """Group of rules that may be satisfied in any order.

    Every execution is forwarded to all remaining member rules.  Each call
    to :meth:`is_consumed` removes one member whose ``rule_passed()`` is
    truthy; the group is finished once no members remain.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_execute(self, clauseelement, *multiparams, **params):
        for member in self.rules:
            member.process_execute(clauseelement, *multiparams, **params)

    def process_cursor_execute(self, statement, parameters, context,
                               executemany):
        for member in self.rules:
            member.process_cursor_execute(statement, parameters, context,
                                          executemany)

    def is_consumed(self):
        """Remove one satisfied member and report whether none remain.

        Fails with AssertionError if no remaining member passed.
        """
        if not self.rules:
            return True
        # lazily probe members, stopping at the first one that passed
        satisfied = next(
            (member for member in list(self.rules)
             if member.rule_passed()),
            None)
        if satisfied is None:
            assert False, 'No assertion rules were satisfied for statement'
        else:
            self.rules.remove(satisfied)
            return not self.rules

    def consume_final(self):
        """Return True only if every member was satisfied."""
        return not self.rules


def _process_engine_statement(query, context):
    """Normalize a statement string received from the engine.

    Under Jython the statement is coerced with ``str()``.  For the
    ``mssql`` engine a trailing ``'; select scope_identity()'`` is
    stripped.  Finally all newlines are removed.
    """
    if util.jython:
        # oracle+zxjdbc passes a PyStatement when returning into
        query = str(query)

    scope_identity_suffix = '; select scope_identity()'
    is_mssql = context.engine.name == 'mssql'
    if is_mssql and query.endswith(scope_identity_suffix):
        query = query[:-len(scope_identity_suffix)]

    return re.sub(r'\n', '', query)


def _process_assertion_statement(query, context):
    paramstyle = context.dialect.paramstyle
    if paramstyle == 'named':
        pass
    elif paramstyle == 'pyformat':
        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
    else:
        # positional params
        repl = None
        if paramstyle == 'qmark':
            repl = "?"
        elif paramstyle == 'format':
            repl = r"%s"
        elif paramstyle == 'numeric':
            repl = None
        query = re.sub(r':([\w_]+)', repl, query)

    return query


class SQLAssert(object):
    """Feed executed statements to an ordered list of assertion rules.

    ``rules`` is None while no assertion is active.  :meth:`add_rules`
    installs a list of rules and :meth:`clear_rules` deactivates it again.
    """

    rules = None

    def add_rules(self, rules):
        """Install ``rules`` (copied into a list) as the active rules."""
        self.rules = list(rules)

    def statement_complete(self):
        """Run the final check of every remaining rule.

        Fails with AssertionError if any rule's ``consume_final()`` is
        falsy.  Expects rules to have been installed with add_rules().
        """
        for rule in self.rules:
            if not rule.consume_final():
                assert False, \
                    'All statements are complete, but pending '\
                    'assertion rules remain'

    def clear_rules(self):
        """Deactivate assertion checking; safe to call repeatedly."""
        # assign rather than ``del self.rules``: deleting the instance
        # attribute raised AttributeError when it was already absent
        # (a second call, or a call before add_rules())
        self.rules = None

    def execute(self, conn, clauseelement, multiparams, params, result):
        """Forward a statement-level execution to the first pending rule.

        The rule is popped once it reports itself consumed.  Fails if
        rules are active but all have already been consumed.
        """
        if self.rules is not None:
            if not self.rules:
                assert False, \
                    'All rules have been exhausted, but further '\
                    'statements remain'
            rule = self.rules[0]
            rule.process_execute(clauseelement, *multiparams, **params)
            if rule.is_consumed():
                self.rules.pop(0)

    def cursor_execute(self, conn, cursor, statement, parameters,
                       context, executemany):
        """Forward a cursor-level execution to the first pending rule.

        Does nothing when no rules are active or none remain.
        """
        if self.rules:
            rule = self.rules[0]
            rule.process_cursor_execute(statement, parameters, context,
                                        executemany)

# Shared module-level SQLAssert instance.  Rules are installed with
# add_rules() and removed with clear_rules(); its execute() and
# cursor_execute() methods take connection-level arguments and are
# presumably registered as engine execution listeners elsewhere in the
# testing package (not visible here -- confirm).
asserter = SQLAssert()