From 3081269e6f1fc51d8d5cfc5120dd10ee2872e871 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 6 Feb 2018 19:30:55 -0500 Subject: Route bulk update/delete exec through new Query._execute_crud method Added support for bulk :meth:`.Query.update` and :meth:`.Query.delete` to the :class:`.ShardedQuery` class within the horiziontal sharding extension. This also adds an additional expansion hook to the bulk update/delete methods :meth:`.Query._execute_crud`. Fixes: #4196 Change-Id: I65f56458176497a8cbdd368f41b879881f06348b --- lib/sqlalchemy/ext/horizontal_shard.py | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) (limited to 'lib/sqlalchemy/ext') diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 6ef4c5612..425d28963 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -64,6 +64,28 @@ class ShardedQuery(Query): # were done, this is where it would happen return iter(partial) + def _execute_crud(self, stmt, mapper): + def exec_for_shard(shard_id): + conn = self._connection_from_session( + mapper=mapper, + shard_id=shard_id, + clause=stmt, + close_with_result=True) + result = conn.execute(stmt, self._params) + return result + + if self._shard_id is not None: + return exec_for_shard(self._shard_id) + else: + rowcount = 0 + results = [] + for shard_id in self.query_chooser(self): + result = exec_for_shard(shard_id) + rowcount += result.rowcount + results.append(result) + + return ShardedResult(results, rowcount) + def _identity_lookup( self, mapper, primary_key_identity, identity_token=None, lazy_loaded_from=None, **kw): @@ -123,6 +145,29 @@ class ShardedQuery(Query): primary_key_identity, _db_load_fn, identity_token=identity_token) +class ShardedResult(object): + """A value object that represents multiple :class:`.ResultProxy` objects. + + This is used by the :meth:`.ShardedQuery._execute_crud` hook to return + an object that takes the place of the single :class:`.ResultProxy`. + + Attribute include ``result_proxies``, which is a sequence of the + actual :class:`.ResultProxy` objects, as well as ``aggregate_rowcount`` + or ``rowcount``, which is the sum of all the individual rowcount values. + + .. versionadded:: 1.3 + """ + + __slots__ = ('result_proxies', 'aggregate_rowcount',) + + def __init__(self, result_proxies, aggregate_rowcount): + self.result_proxies = result_proxies + self.aggregate_rowcount = aggregate_rowcount + + @property + def rowcount(self): + return self.aggregate_rowcount + class ShardedSession(Session): def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, query_cls=ShardedQuery, **kwargs): -- cgit v1.2.1