diff options
Diffstat (limited to 'lib/sqlalchemy/testing/assertsql.py')
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index a596d9743..243493607 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -85,7 +85,7 @@ class CompiledSQL(SQLMatchRule): column_keys=context.compiled.column_keys, inline=context.compiled.inline) ) - _received_statement = re.sub(r'[\n\t]', '', str(compiled)) + _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled)) parameters = execute_observed.parameters if not parameters: @@ -188,21 +188,27 @@ class DialectSQL(CompiledSQL): def _compile_dialect(self, execute_observed): return execute_observed.context.dialect + def _compare_no_space(self, real_stmt, received_stmt): + stmt = re.sub(r'[\n\t]', '', real_stmt) + return received_stmt == stmt + def _received_statement(self, execute_observed): received_stmt, received_params = super(DialectSQL, self).\ _received_statement(execute_observed) + + # TODO: why do we need this part? for real_stmt in execute_observed.statements: - if real_stmt.statement == received_stmt: + if self._compare_no_space(real_stmt.statement, received_stmt): break else: raise AssertionError( "Can't locate compiled statement %r in list of " "statements actually invoked" % received_stmt) + return received_stmt, execute_observed.context.compiled_parameters def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r'[\n\t]', '', self.statement) - # convert our comparison statement to have the # paramstyle of the received paramstyle = execute_observed.context.dialect.paramstyle |
