diff options
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 11 | ||||
-rw-r--r-- | test/orm/test_bulk.py | 38 |
2 files changed, 39 insertions, 10 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index b429aa4c1..254d3bf09 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -472,12 +472,11 @@ def _collect_update_commands( continue if bulk: - pk_params = dict( - (propkey_to_col[propkey]._label, state_dict.get(propkey)) - for propkey in - set(propkey_to_col). - intersection(mapper._pk_keys_by_table[table]) - ) + pk_params = {} + for propkey in set(propkey_to_col).intersection(mapper._pk_keys_by_table[table]): + col = propkey_to_col[propkey] + pk_params[col._label] = state_dict.get(propkey) + params.pop(col.key, None) else: pk_params = {} for col in pks: diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index e27d3b73c..1e0a735c7 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -96,11 +96,41 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): asserter.assert_( CompiledSQL( - "UPDATE users SET id=:id, name=:name WHERE " + "UPDATE users SET name=:name WHERE " "users.id = :users_id", - [{'users_id': 1, 'id': 1, 'name': 'u1new'}, - {'users_id': 2, 'id': 2, 'name': 'u2'}, - {'users_id': 3, 'id': 3, 'name': 'u3new'}] + [{'users_id': 1, 'name': 'u1new'}, + {'users_id': 2, 'name': 'u2'}, + {'users_id': 3, 'name': 'u3new'}] + ) + ) + + def test_bulk_update(self): + User, = self.classes("User",) + + s = Session(expire_on_commit=False) + objects = [ + User(name="u1"), + User(name="u2"), + User(name="u3") + ] + s.add_all(objects) + s.commit() + + s = Session() + with self.sql_execution_asserter() as asserter: + s.bulk_update_mappings( + User, + [{'id': 1, 'name': 'u1new'}, + {'id': 2, 'name': 'u2'}, + {'id': 3, 'name': 'u3new'}] + ) + + asserter.assert_( + CompiledSQL( + "UPDATE users SET name=:name WHERE users.id = :users_id", + [{'users_id': 1, 'name': 'u1new'}, + {'users_id': 2, 'name': 'u2'}, + {'users_id': 3, 'name': 'u3new'}] ) ) |