summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/assertions.py5
-rw-r--r--lib/sqlalchemy/testing/assertsql.py48
2 files changed, 38 insertions, 15 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index dfea33dc7..c0854ea55 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -497,8 +497,9 @@ class AssertsExecutionResults(object):
def assert_sql_execution(self, db, callable_, *rules):
with self.sql_execution_asserter(db) as asserter:
- callable_()
+ result = callable_()
asserter.assert_(*rules)
+ return result
def assert_sql(self, db, callable_, rules):
@@ -512,7 +513,7 @@ class AssertsExecutionResults(object):
newrule = assertsql.CompiledSQL(*rule)
newrules.append(newrule)
- self.assert_sql_execution(db, callable_, *newrules)
+ return self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index e39b6315d..86d850733 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -282,6 +282,32 @@ class AllOf(AssertRule):
self.errormessage = list(self.rules)[0].errormessage
+class EachOf(AssertRule):
+
+ def __init__(self, *rules):
+ self.rules = list(rules)
+
+ def process_statement(self, execute_observed):
+ while self.rules:
+ rule = self.rules[0]
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.pop(0)
+ elif rule.errormessage:
+ self.errormessage = rule.errormessage
+ if rule.consume_statement:
+ break
+
+ if not self.rules:
+ self.is_consumed = True
+
+ def no_more_statements(self):
+ if self.rules and not self.rules[0].is_consumed:
+ self.rules[0].no_more_statements()
+ elif self.rules:
+ super(EachOf, self).no_more_statements()
+
+
class Or(AllOf):
def process_statement(self, execute_observed):
@@ -319,24 +345,20 @@ class SQLAsserter(object):
del self.accumulated
def assert_(self, *rules):
- rules = list(rules)
- observed = list(self._final)
+ rule = EachOf(*rules)
- while observed and rules:
- rule = rules[0]
- rule.process_statement(observed[0])
+ observed = list(self._final)
+ while observed:
+ statement = observed.pop(0)
+ rule.process_statement(statement)
if rule.is_consumed:
- rules.pop(0)
+ break
elif rule.errormessage:
assert False, rule.errormessage
-
- if rule.consume_statement:
- observed.pop(0)
-
- if not observed and rules:
- rules[0].no_more_statements()
- elif not rules and observed:
+ if observed:
assert False, "Additional SQL statements remain"
+ elif not rule.is_consumed:
+ rule.no_more_statements()
@contextlib.contextmanager