diff options
| -rw-r--r-- | CHANGES | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 13 | ||||
| -rw-r--r-- | test/sql/test_select.py | 28 |
4 files changed, 45 insertions, 6 deletions
@@ -191,6 +191,12 @@ CHANGES performed). This occurs if no end-user returning() was specified. + - insert() and update() constructs can now embed bindparam() + objects using names that match the keys of columns. These + bind parameters will circumvent the usual route to those + keys showing up in the VALUES or SET clause of the generated + SQL. [ticket:1579] + - Databases which rely upon postfetch of "last inserted id" to get at a generated sequence value (i.e. MySQL, MS-SQL) now work correctly when there is a composite primary key diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 8af28c8e5..689b518f1 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -465,13 +465,13 @@ class OracleDDLCompiler(compiler.DDLCompiler): class OracleIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = set([x.lower() for x in RESERVED_WORDS]) - illegal_initial_characters = re.compile(r'[0-9_$]') + illegal_initial_characters = set(xrange(0, 10)).union(["_", "$"]) def _bindparam_requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() return (lc_value in self.reserved_words - or self.illegal_initial_characters.match(value[0]) + or value[0] in self.illegal_initial_characters or not self.legal_characters.match(unicode(value)) ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4c3130879..5f5b31c68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -47,7 +47,7 @@ RESERVED_WORDS = set([ 'using', 'verbose', 'when', 'where']) LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I) -ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]') +ILLEGAL_INITIAL_CHARACTERS = set(xrange(0, 10)).union(['$']) BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE) BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE) @@ -776,12 +776,17 @@ class SQLCompiler(engine.Compiled): self.prefetch = [] self.returning = [] + # get the keys of explicitly constructed bindparam() objects + bind_names = set(b.key for b in visitors.iterate(stmt, {}) if b.__visit_name__ == 'bindparam') + if stmt.parameters: + bind_names.update(stmt.parameters) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ (c, self._create_crud_bind_param(c, None, required=True)) - for c in stmt.table.columns + for c in stmt.table.columns if c.key not in bind_names ] required = object() @@ -792,7 +797,7 @@ class SQLCompiler(engine.Compiled): parameters = {} else: parameters = dict((sql._column_as_key(key), required) - for key in self.column_keys) + for key in self.column_keys if key not in bind_names) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): @@ -1312,7 +1317,7 @@ class IdentifierPreparer(object): """Return True if the given identifier requires quoting.""" lc_value = value.lower() return (lc_value in self.reserved_words - or self.illegal_initial_characters.match(value[0]) + or value[0] in self.illegal_initial_characters or not self.legal_characters.match(unicode(value)) or (lc_value != value)) diff --git a/test/sql/test_select.py b/test/sql/test_select.py index 3dc09c9df..1db2559bc 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -1574,7 +1574,35 @@ class CRUDTest(TestBase, AssertsCompiledSQL): s = select([table2.c.othername], table2.c.otherid == table1.c.myid) u = table1.delete(table1.c.name==s) self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + + def test_binds_that_match_columns(self): + """test bind params named after column names replace the normal SET/VALUES generation.""" + + t = table('foo', column('x'), column('y')) + u = t.update().where(t.c.x==bindparam('x')) + + self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x") + self.assert_compile(u, "UPDATE foo SET WHERE foo.x = :x", params={}) + self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x WHERE foo.x = :x") + self.assert_compile(u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x") + self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x, y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) + self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) + + self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x") + self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", params={'x':1}) + self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) + + i = t.insert().values(x=3 + bindparam('x')) + self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))") + self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", params={'x':1, 'y':2}) + + i = t.insert().values(x=3 + bindparam('x2')) + self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))") + self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={}) + self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x':1, 'y':2}) + self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x2':1, 'y':2}) + class InlineDefaultTest(TestBase, AssertsCompiledSQL): def test_insert(self): m = MetaData() |
