diff options
| -rw-r--r-- | doc/build/changelog/changelog_09.rst | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 110 | ||||
| -rw-r--r-- | test/orm/test_update_delete.py | 34 | ||||
| -rw-r--r-- | test/sql/test_update.py | 156 |
5 files changed, 270 insertions, 51 deletions
diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 0efffce62..d59f3ec60 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -83,6 +83,19 @@ Pullreq courtesy Derek Harland. .. change:: + :tags: bug, sql, orm + :tickets: 2912 + + Fixed the multiple-table "UPDATE..FROM" construct, only usable on + MySQL, to correctly render the SET clause among multiple columns + with the same name across tables. This also changes the name used for + the bound parameter in the SET clause to "<tablename>_<colname>" for + the non-primary table only; as this parameter is typically specified + using the :class:`.Column` object directly this should not have an + impact on applications. The fix takes effect for both + :meth:`.Table.update` as well as :meth:`.Query.update` in the ORM. + + .. change:: :tags: bug, oracle :tickets: 2911 diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e507885fa..ed975b8cf 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -895,6 +895,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): and generate inserted_primary_key collection. """ + key_getter = self.compiled._key_getters_for_crud_column[2] + if self.executemany: if len(self.compiled.prefetch): scalar_defaults = {} @@ -918,7 +920,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: val = self.get_update_default(c) if val is not None: - param[c.key] = val + param[key_getter(c)] = val del self.current_parameters else: self.current_parameters = compiled_parameters = \ @@ -931,12 +933,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): val = self.get_update_default(c) if val is not None: - compiled_parameters[c.key] = val + compiled_parameters[key_getter(c)] = val del self.current_parameters if self.isinsert: self.inserted_primary_key = [ - self.compiled_parameters[0].get(c.key, None) + self.compiled_parameters[0].get(key_getter(c), None) for c in self.compiled.\ statement.table.primary_key ] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5c5bfad55..4448f7c7b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \ from .. import util, exc import decimal import itertools +import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms, **kw) + colparams = self._get_colparams(update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None, **kw): + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled): else: stmt_parameters = stmt.parameters + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((elements._column_as_key(key), REQUIRED) + parameters = dict((_column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = elements._column_as_key(k) + colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: @@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled): # add it to values() in an "as-is" state, # coercing right side to bound param if elements._is_literal(v): - v = self.process(elements.BindParameter(None, v, type_=k.type), **kw) + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) else: v = self.process(v.self_group(), **kw) @@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt_parameters: + if self.isupdate and stmt._extra_froms and stmt_parameters: normalized_params = dict( (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) - assert self.isupdate affected_tables = set() - for t in extra_tables: + for t in stmt._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) - check_columns[c.key] = c + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED) + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) else: self.postfetch.append(c) value = self.process(value.self_group(), **kw) @@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group(), **kw)) + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) ) self.prefetch.append(c) elif c.server_onupdate is not None: @@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled): if self.isinsert and stmt.select_names: # for an insert from select, we can only use names that # are given, so only select for those names. - cols = (stmt.table.c[elements._column_as_key(name)] + cols = (stmt.table.c[_column_as_key(name)] for name in stmt.select_names) else: # iterate through all table columns to maintain @@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled): cols = stmt.table.columns for c in cols: - if c.key in parameters and c.key not in check_columns: - value = parameters.pop(c.key) + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, - name=c.key + name=_col_bind_name(c) if not stmt._has_multi_parameters - else "%s_0" % c.key + else "%s_0" % _col_bind_name(c) ) else: if isinstance(value, elements.BindParameter) and \ @@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled): if parameters and stmt_parameters: check = set(parameters).intersection( - elements._column_as_key(k) for k in stmt.parameters + _column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( "Unconsumed column names: %s" % - (", ".join(check)) + (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled): values.extend( [ - ( - c, - self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) ) return values diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 6915ac8a2..ac94fde2f 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -545,12 +545,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest): def define_tables(cls, metadata): Table('users', metadata, Column('id', Integer, primary_key=True), + Column('samename', String(10)), ) Table('documents', metadata, Column('id', Integer, primary_key=True), Column('user_id', None, ForeignKey('users.id')), Column('title', String(32)), - Column('flag', Boolean) + Column('flag', Boolean), + Column('samename', String(10)), ) @classmethod @@ -659,6 +661,34 @@ class UpdateDeleteFromTest(fixtures.MappedTest): ]) ) + @testing.only_on('mysql', 'Multi table update') + def test_update_from_multitable_same_names(self): + Document = self.classes.Document + User = self.classes.User + + s = Session() + + s.query(Document).\ + filter(User.id == Document.user_id).\ + filter(User.id == 2).update({ + Document.samename: 'd_samename', + User.samename: 'u_samename' + } + ) + eq_( + s.query(User.id, Document.samename, User.samename). + filter(User.id == Document.user_id). + order_by(User.id).all(), + [ + (1, None, None), + (1, None, None), + (2, 'd_samename', 'u_samename'), + (2, 'd_samename', 'u_samename'), + (3, None, None), + (3, None, None), + ] + ) + class ExpressionUpdateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): @@ -786,3 +816,5 @@ class InheritTest(fixtures.DeclarativeMappedTest): set(s.query(Person.name, Engineer.engineer_name)), set([('e1', 'e1', ), ('e22', 'e55')]) ) + + diff --git a/test/sql/test_update.py b/test/sql/test_update.py index a8510f374..10306372b 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -192,22 +192,6 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): 'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s', dialect=mysql.dialect()) - def test_alias(self): - table1 = self.tables.mytable - talias1 = table1.alias('t1') - - self.assert_compile(update(talias1, talias1.c.myid == 7), - 'UPDATE mytable AS t1 ' - 'SET name=:name ' - 'WHERE t1.myid = :myid_1', - params={table1.c.name: 'fred'}) - - self.assert_compile(update(talias1, table1.c.myid == 7), - 'UPDATE mytable AS t1 ' - 'SET name=:name ' - 'FROM mytable ' - 'WHERE mytable.myid = :myid_1', - params={table1.c.name: 'fred'}) def test_update_to_expression(self): """test update from an expression. @@ -268,6 +252,64 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, run_create_tables = run_inserts = run_deletes = None + def test_alias_one(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + # this case is nonsensical. the UPDATE is entirely + # against the alias, but we name the table-bound column + # in values. The behavior here isn't really defined + self.assert_compile( + update(talias1, talias1.c.myid == 7). + values({table1.c.name: "fred"}), + 'UPDATE mytable AS t1 ' + 'SET name=:name ' + 'WHERE t1.myid = :myid_1') + + def test_alias_two(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + # Here, compared to + # test_alias_one(), here we actually have UPDATE..FROM, + # which is causing the "table1.c.name" param to be handled + # as an "extra table", hence we see the full table name rendered. + self.assert_compile( + update(talias1, table1.c.myid == 7). + values({table1.c.name: 'fred'}), + 'UPDATE mytable AS t1 ' + 'SET name=:mytable_name ' + 'FROM mytable ' + 'WHERE mytable.myid = :myid_1', + checkparams={'mytable_name': 'fred', 'myid_1': 7}, + ) + + def test_alias_two_mysql(self): + table1 = self.tables.mytable + talias1 = table1.alias('t1') + + self.assert_compile( + update(talias1, table1.c.myid == 7). + values({table1.c.name: 'fred'}), + "UPDATE mytable AS t1, mytable SET mytable.name=%s " + "WHERE mytable.myid = %s", + checkparams={'mytable_name': 'fred', 'myid_1': 7}, + dialect='mysql') + + def test_update_from_multitable_same_name_mysql(self): + users, addresses = self.tables.users, self.tables.addresses + + self.assert_compile( + users.update(). + values(name='newname').\ + values({addresses.c.name: "new address"}).\ + where(users.c.id == addresses.c.user_id), + "UPDATE users, addresses SET addresses.name=%s, " + "users.name=%s WHERE users.id = addresses.user_id", + checkparams={u'addresses_name': 'new address', 'name': 'newname'}, + dialect='mysql' + ) + def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses @@ -455,6 +497,36 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (10, 'chuck')] self._assert_users(users, expected) + @testing.only_on('mysql', 'Multi table update') + def test_exec_multitable_same_name(self): + users, addresses = self.tables.users, self.tables.addresses + + values = { + addresses.c.name: 'ad_ed2', + users.c.name: 'ed2' + } + + testing.db.execute( + addresses.update(). + values(values). + where(users.c.id == addresses.c.user_id). + where(users.c.name == 'ed')) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'ad_ed2', 'ed@wood.com'), + (3, 8, 'ad_ed2', 'ed@bettyboop.com'), + (4, 8, 'ad_ed2', 'ed@lala.com'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + expected = [ + (7, 'jack'), + (8, 'ed2'), + (9, 'fred'), + (10, 'chuck')] + self._assert_users(users, expected) + def _assert_addresses(self, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) eq_(testing.db.execute(stmt).fetchall(), expected) @@ -478,7 +550,16 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('user_id', None, ForeignKey('users.id')), - Column('email_address', String(50), nullable=False)) + Column('email_address', String(50), nullable=False), + ) + + Table('foobar', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('user_id', None, ForeignKey('users.id')), + Column('data', String(30)), + Column('some_update', String(30), onupdate='im the other update') + ) @classmethod def fixtures(cls): @@ -494,6 +575,12 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, (3, 8, 'ed@bettyboop.com'), (4, 9, 'fred@fred.com') ), + foobar=( + ('id', 'user_id', 'data'), + (2, 8, 'd1'), + (3, 8, 'd2'), + (4, 9, 'd3') + ) ) @testing.only_on('mysql', 'Multi table update') @@ -525,6 +612,37 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, self._assert_users(users, expected) @testing.only_on('mysql', 'Multi table update') + def test_defaults_second_table_same_name(self): + users, foobar = self.tables.users, self.tables.foobar + + values = { + foobar.c.data: foobar.c.data + 'a', + users.c.name: 'ed2' + } + + ret = testing.db.execute( + users.update(). + values(values). + where(users.c.id == foobar.c.user_id). + where(users.c.name == 'ed')) + + eq_( + set(ret.prefetch_cols()), + set([users.c.some_update, foobar.c.some_update]) + ) + + expected = [ + (2, 8, 'd1a', 'im the other update'), + (3, 8, 'd2a', 'im the other update'), + (4, 9, 'd3', None)] + self._assert_foobar(foobar, expected) + + expected = [ + (8, 'ed2', 'im the update'), + (9, 'fred', 'value')] + self._assert_users(users, expected) + + @testing.only_on('mysql', 'Multi table update') def test_no_defaults_second_table(self): users, addresses = self.tables.users, self.tables.addresses @@ -548,6 +666,10 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, (9, 'fred', 'value')] self._assert_users(users, expected) + def _assert_foobar(self, foobar, expected): + stmt = foobar.select().order_by(foobar.c.id) + eq_(testing.db.execute(stmt).fetchall(), expected) + def _assert_addresses(self, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) eq_(testing.db.execute(stmt).fetchall(), expected) |
