diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-03 16:34:03 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-04-03 16:34:03 +0000 |
| commit | abb10856dcea07ca4d38d28df4e493d11d8fd345 (patch) | |
| tree | 3936f0a7a9ddf7560928b6e21c4fd4a8669c337a | |
| parent | a27d6be28a0beb35da2e3eb1dfed7ab7460d7654 (diff) | |
| download | sqlalchemy-abb10856dcea07ca4d38d28df4e493d11d8fd345.tar.gz | |
- case() interprets the "THEN" expressions
as values by default, meaning case([(x==y, "foo")]) will
interpret "foo" as a bound value, not a SQL expression.
use text(expr) for literal SQL expressions in this case.
For the criterion itself, these may be literal strings
only if the "value" keyword is present, otherwise SA
will force explicit usage of either text() or literal().
| -rw-r--r-- | CHANGES | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 47 | ||||
| -rw-r--r-- | test/sql/case_statement.py | 35 |
3 files changed, 68 insertions, 23 deletions
@@ -196,8 +196,13 @@ CHANGES symptom. - The case() function now also takes a dictionary as its whens - parameter. But beware that it doesn't escape literals, use - the literal construct for that. + parameter. It also interprets the "THEN" expressions + as values by default, meaning case([(x==y, "foo")]) will + interpret "foo" as a bound value, not a SQL expression. + use text(expr) for literal SQL expressions in this case. + For the criterion itself, these may be literal strings + only if the "value" keyword is present, otherwise SA + will force explicit usage of either text() or literal(). - declarative extension - The "synonym" function is now directly usable with diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index cc97227a7..39a2ae3eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -392,7 +392,7 @@ def not_(clause): result. """ - return operators.inv(clause) + return operators.inv(_literal_as_binds(clause)) def distinct(expr): """Return a ``DISTINCT`` clause.""" @@ -416,24 +416,45 @@ def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. whens - A sequence of pairs or a dict to be translated into "when / then" clauses. + A sequence of pairs, or alternatively a dict, + to be translated into "WHEN / THEN" clauses. value - Optional for simple case statements. + Optional for simple case statements, produces + a column expression as in "CASE <expr> WHEN ..." else\_ - Optional as well, for case defaults. + Optional as well, for case defaults produces + the "ELSE" portion of the "CASE" statement. + + The expressions used for THEN and ELSE, + when specified as strings, will be interpreted + as bound values. To specify textual SQL expressions + for these, use the text(<string>) construct. + + The expressions used for the WHEN criterion + may only be literal strings when "value" is + present, i.e. CASE table.somecol WHEN "x" THEN "y". + Otherwise, literal strings are not accepted + in this position, and either the text(<string>) + or literal(<string>) constructs must be used to + interpret raw string values. + """ - try: whens = util.dictlike_iteritems(whens) except TypeError: pass - - whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) + + if value: + crit_filter = _literal_as_binds + else: + crit_filter = _no_literals + + whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None) for (c,r) in whens] - if not else_ is None: - whenlist.append(ClauseList('ELSE', else_, operator=None)) + if else_ is not None: + whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None)) if whenlist: type = list(whenlist[-1])[-1].type else: @@ -842,6 +863,14 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _no_literals(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + else: + return element + def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 730517b21..257298c8e 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -2,10 +2,11 @@ import testenv; testenv.configure_for_tests() import sys from sqlalchemy import * from testlib import * -from sqlalchemy import util +from sqlalchemy import util, exceptions +from sqlalchemy.sql import table, column -class CaseTest(TestBase): +class CaseTest(TestBase, AssertsCompiledSQL): def setUpAll(self): metadata = MetaData(testing.db) @@ -30,9 +31,9 @@ class CaseTest(TestBase): def testcase(self): inner = select([case([ [info_table.c.pk < 3, - literal('lessthan3', type_=String)], + 'lessthan3'], [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type_=String)]]).label('x'), + 'gt3']]).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -69,9 +70,9 @@ class CaseTest(TestBase): w_else = select([case([ [info_table.c.pk < 3, - literal(3, type_=Integer)], + 3], [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type_=Integer)]], + 6]], else_ = 0).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -87,12 +88,21 @@ class CaseTest(TestBase): (0, 6, 'pk_6_data') ] + def test_literal_interpretation(self): + t = table('test', column('col1')) + + self.assertRaises(exceptions.ArgumentError, case, [("x", "y")]) + + self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END") + self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :test_col1_1) THEN :param_1 ELSE :param_2 END") + + @testing.fails_on('maxdb') def testcase_with_dict(self): query = select([case({ - info_table.c.pk < 3: literal('lessthan3'), - info_table.c.pk >= 3: literal('gt3'), - }, else_=literal('other')), + info_table.c.pk < 3: 'lessthan3', + info_table.c.pk >= 3: 'gt3', + }, else_='other'), info_table.c.pk, info_table.c.info ], from_obj=[info_table]) @@ -106,13 +116,14 @@ class CaseTest(TestBase): ] simple_query = select([case({ - 1: literal('one'), - 2: literal('two'), - }, value=info_table.c.pk, else_=literal('other')), + 1: 'one', + 2: 'two', + }, value=info_table.c.pk, else_='other'), info_table.c.pk ], whereclause=info_table.c.pk < 4, from_obj=[info_table]) + assert simple_query.execute().fetchall() == [ ('one', 1), ('two', 2), |
