From bac14cdf477151f5d3bea3450565462a66c17ee2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 23 Oct 2012 12:08:20 -0400 Subject: Added a new method :meth:`.Engine.execution_options` to :class:`.Engine`. This method works similarly to :class:`.Connection.execution_options` in that it creates a copy of the parent object which will refer to the new set of options. The method can be used to build sharding schemes where each engine shares the same underlying pool of connections. The method has been tested against the horizontal shard recipe in the ORM as well. --- test/ext/test_horizontal_shard.py | 86 +++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 21 deletions(-) (limited to 'test/ext') diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 0b9b89d2b..024510886 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -5,6 +5,7 @@ from sqlalchemy import sql, util from sqlalchemy.orm import * from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.sql import operators +from sqlalchemy import pool from sqlalchemy.testing import fixtures from sqlalchemy import testing from sqlalchemy.testing.engines import testing_engine @@ -13,19 +14,19 @@ from nose import SkipTest # TODO: ShardTest can be turned into a base for further subclasses -class ShardTest(fixtures.TestBase): + + + +class ShardTest(object): __skip_if__ = (lambda: util.win32,) + __requires__ = 'sqlite', + + schema = None def setUp(self): global db1, db2, db3, db4, weather_locations, weather_reports - try: - db1 = testing_engine('sqlite:///shard1.db', options=dict(pool_threadlocal=True)) - except ImportError: - raise SkipTest('Requires sqlite') - db2 = testing_engine('sqlite:///shard2.db') - db3 = testing_engine('sqlite:///shard3.db') - db4 = testing_engine('sqlite:///shard4.db') + db1, db2, db3, db4 = self._init_dbs() meta = MetaData() ids = Table('ids', meta, @@ -36,13 +37,14 @@ class ShardTest(fixtures.TestBase): c = db1.contextual_connect() nextid = c.execute(ids.select(for_update=True)).scalar() - c.execute(ids.update(values={ids.c.nextid : ids.c.nextid + 1})) + c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) return nextid weather_locations = Table("weather_locations", meta, Column('id', Integer, primary_key=True, default=id_generator), Column('continent', String(30), nullable=False), - Column('city', String(50), nullable=False) + Column('city', String(50), nullable=False), + schema=self.schema ) weather_reports = Table( @@ -50,10 +52,11 @@ class ShardTest(fixtures.TestBase): meta, Column('id', Integer, primary_key=True), Column('location_id', Integer, - ForeignKey('weather_locations.id')), + ForeignKey(weather_locations.c.id)), Column('temperature', Float), Column('report_time', DateTime, default=datetime.datetime.now), + schema=self.schema ) for db in (db1, db2, db3, db4): @@ -64,13 +67,6 @@ class ShardTest(fixtures.TestBase): self.setup_session() self.setup_mappers() - def tearDown(self): - clear_mappers() - - for db in (db1, db2, db3, db4): - db.connect().invalidate() - for i in range(1,5): - os.remove("shard%d.db" % i) @classmethod def setup_session(cls): @@ -139,11 +135,12 @@ class ShardTest(fixtures.TestBase): self.temperature = temperature mapper(WeatherLocation, weather_locations, properties={ - 'reports':relationship(Report, backref='location'), + 'reports': relationship(Report, backref='location'), 'city': deferred(weather_locations.c.city), }) mapper(Report, weather_reports) + def _fixture_data(self): tokyo = WeatherLocation('Asia', 'Tokyo') newyork = WeatherLocation('North America', 'New York') @@ -204,7 +201,8 @@ class ShardTest(fixtures.TestBase): event.listen(WeatherLocation, "load", load) sess = self._fixture_data() - tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").set_shard("asia").one() + tokyo = sess.query(WeatherLocation).\ + filter_by(city="Tokyo").set_shard("asia").one() sess.query(WeatherLocation).all() eq_( @@ -212,4 +210,50 @@ class ShardTest(fixtures.TestBase): ['asia', 'north_america', 'north_america', 'europe', 'europe', 'south_america', 'south_america'] - ) \ No newline at end of file + ) + +class DistinctEngineShardTest(ShardTest, fixtures.TestBase): + + def _init_dbs(self): + db1 = testing_engine('sqlite:///shard1.db', + options=dict(pool_threadlocal=True)) + db2 = testing_engine('sqlite:///shard2.db') + db3 = testing_engine('sqlite:///shard3.db') + db4 = testing_engine('sqlite:///shard4.db') + + return db1, db2, db3, db4 + + def tearDown(self): + clear_mappers() + + for db in (db1, db2, db3, db4): + db.connect().invalidate() + for i in range(1, 5): + os.remove("shard%d.db" % i) + +class AttachedFileShardTest(ShardTest, fixtures.TestBase): + schema = "changeme" + + def _init_dbs(self): + db1 = testing_engine('sqlite://', options={"execution_options": + {"shard_id": "shard1"}}) + assert db1._has_events + + db2 = db1.execution_options(shard_id="shard2") + db3 = db1.execution_options(shard_id="shard3") + db4 = db1.execution_options(shard_id="shard4") + + import re + @event.listens_for(db1, "before_cursor_execute", retval=True) + def _switch_shard(conn, cursor, stmt, params, context, executemany): + shard_id = conn._execution_options['shard_id'] + # because SQLite can't just give us a "use" statement, we have + # to use the schema hack to locate table names + if shard_id: + stmt = re.sub(r"\"?changeme\"?\.", shard_id + "_", stmt) + + return stmt, params + + return db1, db2, db3, db4 + + -- cgit v1.2.1