diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-01-17 21:36:52 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2015-01-17 21:58:49 -0500 |
| commit | f49c367ef712d080e630ba722f96903922d7de7b (patch) | |
| tree | 003e894ee985784fc8806f735d7a28c53580e01e | |
| parent | 469b6fabaf78fa0aad485005fd7bc8be7fe27f92 (diff) | |
| download | sqlalchemy-f49c367ef712d080e630ba722f96903922d7de7b.tar.gz | |
- fix a regression from ref #3178, where dialects that don't actually support
sane multi rowcount (e.g. pyodbc) would fail on multirow update. add
a test that mocks this breakage into plain dialects
| -rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 16 | ||||
| -rw-r--r-- | test/orm/test_unitofworkv2.py | 68 |
2 files changed, 77 insertions, 7 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index f477e1dd7..dbf1d3eb4 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -617,6 +617,14 @@ def _emit_update_statements(base_mapper, uowtransaction, rows = 0 records = list(records) + # TODO: would be super-nice to not have to determine this boolean + # inside the loop here, in the 99.9999% of the time there's only + # one connection in use + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + if hasvalue: for state, state_dict, params, mapper, \ connection, value_params in records: @@ -635,9 +643,7 @@ def _emit_update_statements(base_mapper, uowtransaction, value_params) rows += c.rowcount else: - if needs_version_id and \ - not connection.dialect.supports_sane_multi_rowcount and \ - connection.dialect.supports_sane_rowcount: + if not allow_multirow: for state, state_dict, params, mapper, \ connection, value_params in records: c = cached_connections[connection].\ @@ -654,6 +660,7 @@ def _emit_update_statements(base_mapper, uowtransaction, rows += c.rowcount else: multiparams = [rec[2] for rec in records] + c = cached_connections[connection].\ execute(statement, multiparams) @@ -670,7 +677,8 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) - if connection.dialect.supports_sane_rowcount: + if assert_multirow or assert_singlerow and \ + len(multiparams) == 1: if rows != len(records): raise orm_exc.StaleDataError( "UPDATE statement on table '%s' expected to " diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 374a77237..681b104cf 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -3,13 +3,13 @@ from sqlalchemy import testing from sqlalchemy.testing import engines from sqlalchemy.testing.schema import Table, Column from test.orm import _fixtures -from sqlalchemy import exc -from sqlalchemy.testing import fixtures +from sqlalchemy import exc, util +from sqlalchemy.testing import fixtures, config from sqlalchemy import Integer, String, ForeignKey, func from sqlalchemy.orm import mapper, relationship, backref, \ create_session, unitofwork, attributes,\ Session, exc as orm_exc -from sqlalchemy.testing.mock import Mock +from sqlalchemy.testing.mock import Mock, patch from sqlalchemy.testing.assertsql import AllOf, CompiledSQL from sqlalchemy import event @@ -1473,6 +1473,67 @@ class BasicStaleChecksTest(fixtures.MappedTest): sess.flush ) + def test_update_single_missing_broken_multi_rowcount(self): + @util.memoized_property + def rowcount(self): + if len(self.context.compiled_parameters) > 1: + return -1 + else: + return self.context.rowcount + + with patch.object( + config.db.dialect, "supports_sane_multi_rowcount", False): + with patch( + "sqlalchemy.engine.result.ResultProxy.rowcount", + rowcount): + Parent, Child = self._fixture() + sess = Session() + p1 = Parent(id=1, data=2) + sess.add(p1) + sess.flush() + + sess.execute(self.tables.parent.delete()) + + p1.data = 3 + assert_raises_message( + orm_exc.StaleDataError, + "UPDATE statement on table 'parent' expected to " + "update 1 row\(s\); 0 were matched.", + sess.flush + ) + + def test_update_multi_missing_broken_multi_rowcount(self): + @util.memoized_property + def rowcount(self): + if len(self.context.compiled_parameters) > 1: + return -1 + else: + return self.context.rowcount + + with patch.object( + config.db.dialect, "supports_sane_multi_rowcount", False): + with patch( + "sqlalchemy.engine.result.ResultProxy.rowcount", + rowcount): + Parent, Child = self._fixture() + sess = Session() + p1 = Parent(id=1, data=2) + p2 = Parent(id=2, data=3) + sess.add_all([p1, p2]) + sess.flush() + + sess.execute(self.tables.parent.delete().where(Parent.id == 1)) + + p1.data = 3 + p2.data = 4 + sess.flush() # no exception + + # update occurred for remaining row + eq_( + sess.query(Parent.id, Parent.data).all(), + [(2, 4)] + ) + @testing.requires.sane_multi_rowcount def test_delete_multi_missing_warning(self): Parent, Child = self._fixture() @@ -1544,6 +1605,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): T(id=10, data='t10', def_='def3'), T(id=11, data='t11'), ]) + self.assert_sql_execution( testing.db, sess.flush, |
