diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 48 |
2 files changed, 38 insertions, 15 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index dfea33dc7..c0854ea55 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -497,8 +497,9 @@ class AssertsExecutionResults(object): def assert_sql_execution(self, db, callable_, *rules): with self.sql_execution_asserter(db) as asserter: - callable_() + result = callable_() asserter.assert_(*rules) + return result def assert_sql(self, db, callable_, rules): @@ -512,7 +513,7 @@ class AssertsExecutionResults(object): newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) - self.assert_sql_execution(db, callable_, *newrules) + return self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e39b6315d..86d850733 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -282,6 +282,32 @@ class AllOf(AssertRule): self.errormessage = list(self.rules)[0].errormessage +class EachOf(AssertRule): + + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super(EachOf, self).no_more_statements() + + class Or(AllOf): def process_statement(self, execute_observed): @@ -319,24 +345,20 @@ class SQLAsserter(object): del self.accumulated def assert_(self, *rules): - rules = list(rules) - observed = list(self._final) + rule = EachOf(*rules) - while observed and rules: - rule = rules[0] - rule.process_statement(observed[0]) + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) if rule.is_consumed: - rules.pop(0) + break elif rule.errormessage: assert False, rule.errormessage - - if rule.consume_statement: - observed.pop(0) - - if not observed and rules: - rules[0].no_more_statements() - elif not rules and observed: + if observed: assert False, "Additional SQL statements remain" + elif not rule.is_consumed: + rule.no_more_statements() @contextlib.contextmanager |
