summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES3
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py19
-rw-r--r--lib/sqlalchemy/orm/query.py19
-rw-r--r--test/lib/requires.py17
-rw-r--r--test/orm/test_query.py13
5 files changed, 53 insertions, 18 deletions
diff --git a/CHANGES b/CHANGES
index 87c20d7a0..abdb9299e 100644
--- a/CHANGES
+++ b/CHANGES
@@ -21,6 +21,9 @@ CHANGES
a deprecation warning in 0.6.8.
[ticket:2144]
+ - added Query.with_session() method, switches
+ Query to use a different session.
+
- sql
- Some improvements to error handling inside
of the execute procedure to ensure auto-close
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index dfd471c78..6aafb2274 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -40,20 +40,21 @@ class ShardedQuery(Query):
return q
def _execute_and_instances(self, context):
- if self._shard_id is not None:
- context.attributes['shard_id'] = self._shard_id
- result = self.session.connection(
+ def iter_for_shard(shard_id):
+ context.attributes['shard_id'] = shard_id
+ result = self._connection_from_session(
mapper=self._mapper_zero(),
- shard_id=self._shard_id).execute(context.statement, self._params)
+ shard_id=shard_id).execute(
+ context.statement,
+ self._params)
return self.instances(result, context)
+
+ if self._shard_id is not None:
+ return iter_for_shard(self._shard_id)
else:
partial = []
for shard_id in self.query_chooser(self):
- context.attributes['shard_id'] = shard_id
- result = self.session.connection(
- mapper=self._mapper_zero(),
- shard_id=shard_id).execute(context.statement, self._params)
- partial = partial + list(self.instances(result, context))
+ partial.extend(iter_for_shard(shard_id))
# if some kind of in memory 'sorting'
# were done, this is where it would happen
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index ef42e0d3a..75fd5870e 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -773,6 +773,14 @@ class Query(object):
m = _MapperEntity(self, entity)
self._setup_aliasizers([m])
+ @_generative()
+ def with_session(self, session):
+ """Return a :class:`Query` that will use the given :class:`.Session`.
+
+ """
+
+ self.session = session
+
def from_self(self, *entities):
"""return a Query that selects from this Query's
SELECT statement.
@@ -1766,13 +1774,18 @@ class Query(object):
self.session._autoflush()
return self._execute_and_instances(context)
- def _execute_and_instances(self, querycontext):
+ def _connection_from_session(self, **kw):
conn = self.session.connection(
+ **kw)
+ if self._execution_options:
+ conn = conn.execution_options(**self._execution_options)
+ return conn
+
+ def _execute_and_instances(self, querycontext):
+ conn = self._connection_from_session(
mapper = self._mapper_zero_or_none(),
clause = querycontext.statement,
close_with_result=True)
- if self._execution_options:
- conn = conn.execution_options(**self._execution_options)
result = conn.execute(querycontext.statement, self._params)
return self.instances(result, querycontext)
diff --git a/test/lib/requires.py b/test/lib/requires.py
index 1be308fe7..5f3eb1c9a 100644
--- a/test/lib/requires.py
+++ b/test/lib/requires.py
@@ -287,12 +287,17 @@ def cextensions(fn):
)
def dbapi_lastrowid(fn):
- return _chain_decorators_on(
- fn,
- fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
- 'sqlite+pysqlite', 'mysql+pymysql'),
- fails_if(lambda: util.pypy),
- )
+ if util.pypy:
+ return _chain_decorators_on(
+ fn,
+ fails_if(lambda:True)
+ )
+ else:
+ return _chain_decorators_on(
+ fn,
+ fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
+ 'sqlite+pysqlite', 'mysql+pymysql'),
+ )
def sane_multi_rowcount(fn):
return _chain_decorators_on(
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index a87e1398a..3a60e878d 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -66,6 +66,19 @@ class QueryTest(_fixtures.FixtureTest):
configure_mappers()
+class MiscTest(QueryTest):
+ run_create_tables = None
+ run_inserts = None
+
+ def test_with_session(self):
+ User = self.classes.User
+ s1 = Session()
+ s2 = Session()
+ q1 = s1.query(User)
+ q2 = q1.with_session(s2)
+ assert q2.session is s2
+ assert q1.session is s1
+
class RowTupleTest(QueryTest):
run_setup_mappers = None