summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-04-18 22:54:40 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-04-18 22:54:40 +0000
commit4fffc21c87cbdfc538fe2924f82bf1591823856d (patch)
treef200a79e608f9e901baf515ce3d0e1b3b21b8bf6
parent7efd23b23cbbd1d714cc31e44e776b7e1e9af319 (diff)
downloadsqlalchemy-4fffc21c87cbdfc538fe2924f82bf1591823856d.tar.gz
- the "where" criterion of an update() and delete() now correlates
embedded select() statements against the table being updated or deleted. this works the same as nested select() statement correlation, and can be disabled via the correlate=False flag on the embedded select().
-rw-r--r--CHANGES5
-rw-r--r--lib/sqlalchemy/sql.py26
-rw-r--r--test/sql/select.py7
3 files changed, 32 insertions, 6 deletions
diff --git a/CHANGES b/CHANGES
index c560f4fe7..3c600c194 100644
--- a/CHANGES
+++ b/CHANGES
@@ -32,6 +32,11 @@
of unicode situations that occur in db's such as MS-SQL to be
better handled and allows subclassing of the Unicode datatype.
[ticket:522]
+ - the "where" criterion of an update() and delete() now correlates
+ embedded select() statements against the table being updated or
+ deleted. this works the same as nested select() statement
+ correlation, and can be disabled via the correlate=False flag on
+ the embedded select().
- column labels are now generated in the compilation phase, which
means their lengths are dialect-dependent. So on oracle a label
that gets truncated to 30 chars will go out to 63 characters
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 94b618491..a8663ed4c 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause):
use_labels=False, distinct=False, for_update=False,
engine=None, limit=None, offset=None, scalar=False,
correlate=True):
+ # TODO: docstring !
_SelectBaseMixin.__init__(self)
self.__froms = util.OrderedSet()
self.__hide_froms = util.Set([self])
@@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause):
self.is_scalar = scalar
# indicates if this select statement, as a subquery, should automatically correlate
- # its FROM clause to that of an enclosing select statement.
+ # its FROM clause to that of an enclosing select, update, or delete statement.
# note that the "correlate" method can be used to explicitly add a value to be correlated.
self.should_correlate = correlate
@@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement):
def supports_execution(self):
return True
+ class _SelectCorrelator(NoColumnVisitor):
+ def __init__(self, table):
+ NoColumnVisitor.__init__(self)
+ self.table = table
+
+ def visit_select(self, select):
+ if select.should_correlate:
+ select.correlate(self.table)
+
+ def _process_whereclause(self, whereclause):
+ if whereclause is not None:
+ _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
+ return whereclause
+
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
@@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement):
i +=1
parameters = pp
+ correlator = _UpdateBase._SelectCorrelator(self.table)
for key in parameters.keys():
value = parameters[key]
- if isinstance(value, Select):
- value.correlate(self.table)
+ if isinstance(value, ClauseElement):
+ correlator.traverse(value)
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
@@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase):
class _Update(_UpdateBase):
def __init__(self, table, whereclause, values=None):
self.table = table
- self.whereclause = whereclause
+ self.whereclause = self._process_whereclause(whereclause)
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
@@ -2625,7 +2641,7 @@ class _Update(_UpdateBase):
class _Delete(_UpdateBase):
def __init__(self, table, whereclause):
self.table = table
- self.whereclause = whereclause
+ self.whereclause = self._process_whereclause(whereclause)
def get_children(self, **kwargs):
if self.whereclause is not None:
diff --git a/test/sql/select.py b/test/sql/select.py
index 91b293cbe..1d0a63e2f 100644
--- a/test/sql/select.py
+++ b/test/sql/select.py
@@ -828,10 +828,15 @@ class CRUDTest(SQLTest):
u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s})
self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name")
- # test a correlated WHERE clause
+ # test a non-correlated WHERE clause
s = select([table2.c.othername], table2.c.otherid == 7)
u = update(table1, table1.c.name==s)
self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)")
+
+ # test one that is actually correlated...
+ s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
+ u = table1.update(table1.c.name==s)
+ self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
def testdelete(self):
self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")