summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES6
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py13
-rw-r--r--test/sql/test_select.py28
4 files changed, 45 insertions, 6 deletions
diff --git a/CHANGES b/CHANGES
index 86baaa78a..baa34e789 100644
--- a/CHANGES
+++ b/CHANGES
@@ -191,6 +191,12 @@ CHANGES
performed). This occurs if no end-user returning() was
specified.
+ - insert() and update() constructs can now embed bindparam()
+ objects using names that match the keys of columns. These
+ bind parameters will circumvent the usual route to those
+ keys showing up in the VALUES or SET clause of the generated
+ SQL. [ticket:1579]
+
- Databases which rely upon postfetch of "last inserted id"
to get at a generated sequence value (i.e. MySQL, MS-SQL)
now work correctly when there is a composite primary key
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 8af28c8e5..689b518f1 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -465,13 +465,13 @@ class OracleDDLCompiler(compiler.DDLCompiler):
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = set([x.lower() for x in RESERVED_WORDS])
- illegal_initial_characters = re.compile(r'[0-9_$]')
+ illegal_initial_characters = set(xrange(0, 10)).union(["_", "$"])
def _bindparam_requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (lc_value in self.reserved_words
- or self.illegal_initial_characters.match(value[0])
+ or value[0] in self.illegal_initial_characters
or not self.legal_characters.match(unicode(value))
)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 4c3130879..5f5b31c68 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -47,7 +47,7 @@ RESERVED_WORDS = set([
'using', 'verbose', 'when', 'where'])
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')
+ILLEGAL_INITIAL_CHARACTERS = set(xrange(0, 10)).union(['$'])
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)
@@ -776,12 +776,17 @@ class SQLCompiler(engine.Compiled):
self.prefetch = []
self.returning = []
+ # get the keys of explicitly constructed bindparam() objects
+ bind_names = set(b.key for b in visitors.iterate(stmt, {}) if b.__visit_name__ == 'bindparam')
+ if stmt.parameters:
+ bind_names.update(stmt.parameters)
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.column_keys is None and stmt.parameters is None:
return [
(c, self._create_crud_bind_param(c, None, required=True))
- for c in stmt.table.columns
+ for c in stmt.table.columns if c.key not in bind_names
]
required = object()
@@ -792,7 +797,7 @@ class SQLCompiler(engine.Compiled):
parameters = {}
else:
parameters = dict((sql._column_as_key(key), required)
- for key in self.column_keys)
+ for key in self.column_keys if key not in bind_names)
if stmt.parameters is not None:
for k, v in stmt.parameters.iteritems():
@@ -1312,7 +1317,7 @@ class IdentifierPreparer(object):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
return (lc_value in self.reserved_words
- or self.illegal_initial_characters.match(value[0])
+ or value[0] in self.illegal_initial_characters
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
diff --git a/test/sql/test_select.py b/test/sql/test_select.py
index 3dc09c9df..1db2559bc 100644
--- a/test/sql/test_select.py
+++ b/test/sql/test_select.py
@@ -1574,7 +1574,35 @@ class CRUDTest(TestBase, AssertsCompiledSQL):
s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
u = table1.delete(table1.c.name==s)
self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+
+ def test_binds_that_match_columns(self):
+ """test bind params named after column names replace the normal SET/VALUES generation."""
+
+ t = table('foo', column('x'), column('y'))
+ u = t.update().where(t.c.x==bindparam('x'))
+
+ self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x")
+ self.assert_compile(u, "UPDATE foo SET WHERE foo.x = :x", params={})
+ self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x WHERE foo.x = :x")
+ self.assert_compile(u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x")
+ self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x, y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+ self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+
+ self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x")
+ self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", params={'x':1})
+ self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+
+ i = t.insert().values(x=3 + bindparam('x'))
+ self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))")
+ self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", params={'x':1, 'y':2})
+
+ i = t.insert().values(x=3 + bindparam('x2'))
+ self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))")
+ self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={})
+ self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x':1, 'y':2})
+ self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x2':1, 'y':2})
+
class InlineDefaultTest(TestBase, AssertsCompiledSQL):
def test_insert(self):
m = MetaData()