summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/horizontal_shard.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-03-16 12:43:22 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2011-03-16 12:43:22 -0400
commitff1868b3f045435e3829eefa6d6911f492569dca (patch)
tree19e031dc0d3659675b07814c1d7c70aab655be72 /lib/sqlalchemy/ext/horizontal_shard.py
parent464835e409dbd607a8a1fbbc8399f6c0c14b3ea8 (diff)
downloadsqlalchemy-ff1868b3f045435e3829eefa6d6911f492569dca.tar.gz
- The horizontal_shard ShardedSession class accepts the common
Session argument "query_cls" as a constructor argument, to enable further subclassing of ShardedQuery. [ticket:2090] - The Beaker caching example allows a "query_cls" argument to the query_callable() function. [ticket:2090]
Diffstat (limited to 'lib/sqlalchemy/ext/horizontal_shard.py')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py104
1 files changed, 52 insertions, 52 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 32c767e11..dfd471c78 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -21,9 +21,59 @@ from sqlalchemy.orm.query import Query
__all__ = ['ShardedSession', 'ShardedQuery']
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self._shard_id = None
+
+ def set_shard(self, shard_id):
+ """return a new query, limited to a single shard ID.
+
+ all subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+ """
+
+ q = self._clone()
+ q._shard_id = shard_id
+ return q
+
+ def _execute_and_instances(self, context):
+ if self._shard_id is not None:
+ context.attributes['shard_id'] = self._shard_id
+ result = self.session.connection(
+ mapper=self._mapper_zero(),
+ shard_id=self._shard_id).execute(context.statement, self._params)
+ return self.instances(result, context)
+ else:
+ partial = []
+ for shard_id in self.query_chooser(self):
+ context.attributes['shard_id'] = shard_id
+ result = self.session.connection(
+ mapper=self._mapper_zero(),
+ shard_id=shard_id).execute(context.statement, self._params)
+ partial = partial + list(self.instances(result, context))
+
+ # if some kind of in memory 'sorting'
+ # were done, this is where it would happen
+ return iter(partial)
+
+ def get(self, ident, **kwargs):
+ if self._shard_id is not None:
+ return super(ShardedQuery, self).get(ident)
+ else:
+ ident = util.to_list(ident)
+ for shard_id in self.id_chooser(self, ident):
+ o = self.set_shard(shard_id).get(ident, **kwargs)
+ if o is not None:
+ return o
+ else:
+ return None
class ShardedSession(Session):
- def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
+ def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None,
+ query_cls=ShardedQuery, **kwargs):
"""Construct a ShardedSession.
:param shard_chooser: A callable which, passed a Mapper, a mapped instance, and possibly a
@@ -45,13 +95,12 @@ class ShardedSession(Session):
objects.
"""
- super(ShardedSession, self).__init__(**kwargs)
+ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
self.shard_chooser = shard_chooser
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
self.connection_callable = self.connection
- self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
@@ -75,53 +124,4 @@ class ShardedSession(Session):
def bind_shard(self, shard_id, bind):
self.__binds[shard_id] = bind
-class ShardedQuery(Query):
- def __init__(self, *args, **kwargs):
- super(ShardedQuery, self).__init__(*args, **kwargs)
- self.id_chooser = self.session.id_chooser
- self.query_chooser = self.session.query_chooser
- self._shard_id = None
-
- def set_shard(self, shard_id):
- """return a new query, limited to a single shard ID.
-
- all subsequent operations with the returned query will
- be against the single shard regardless of other state.
- """
-
- q = self._clone()
- q._shard_id = shard_id
- return q
-
- def _execute_and_instances(self, context):
- if self._shard_id is not None:
- context.attributes['shard_id'] = self._shard_id
- result = self.session.connection(
- mapper=self._mapper_zero(),
- shard_id=self._shard_id).execute(context.statement, self._params)
- return self.instances(result, context)
- else:
- partial = []
- for shard_id in self.query_chooser(self):
- context.attributes['shard_id'] = shard_id
- result = self.session.connection(
- mapper=self._mapper_zero(),
- shard_id=shard_id).execute(context.statement, self._params)
- partial = partial + list(self.instances(result, context))
-
- # if some kind of in memory 'sorting'
- # were done, this is where it would happen
- return iter(partial)
-
- def get(self, ident, **kwargs):
- if self._shard_id is not None:
- return super(ShardedQuery, self).get(ident)
- else:
- ident = util.to_list(ident)
- for shard_id in self.id_chooser(self, ident):
- o = self.set_shard(shard_id).get(ident, **kwargs)
- if o is not None:
- return o
- else:
- return None