diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2009-08-06 21:11:27 +0000 |
| commit | 8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch) | |
| tree | ae9e27d12c9fbf8297bb90469509e1cb6a206242 /test/engine | |
| parent | 7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff) | |
| download | sqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz | |
merge 0.6 series to trunk.
Diffstat (limited to 'test/engine')
| -rw-r--r-- | test/engine/test_bind.py | 4 | ||||
| -rw-r--r-- | test/engine/test_ddlevents.py | 72 | ||||
| -rw-r--r-- | test/engine/test_execute.py | 34 | ||||
| -rw-r--r-- | test/engine/test_metadata.py | 41 | ||||
| -rw-r--r-- | test/engine/test_parseconnect.py | 127 | ||||
| -rw-r--r-- | test/engine/test_pool.py | 630 | ||||
| -rw-r--r-- | test/engine/test_reconnect.py | 18 | ||||
| -rw-r--r-- | test/engine/test_reflection.py | 467 | ||||
| -rw-r--r-- | test/engine/test_transaction.py | 34 |
9 files changed, 948 insertions, 479 deletions
diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py index 7fd3009bc..1122f1632 100644 --- a/test/engine/test_bind.py +++ b/test/engine/test_bind.py @@ -121,7 +121,7 @@ class BindTest(testing.TestBase): table = Table('test_table', metadata, Column('foo', Integer)) - metadata.connect(bind) + metadata.bind = bind assert metadata.bind is table.bind is bind metadata.create_all() @@ -199,7 +199,7 @@ class BindTest(testing.TestBase): try: e = elem(bind=bind) assert e.bind is bind - e.execute() + e.execute().close() finally: if isinstance(bind, engine.Connection): bind.close() diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 5716006d9..434a5d873 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -1,12 +1,13 @@ -from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message -from sqlalchemy.schema import DDL +from sqlalchemy.test.testing import assert_raises, assert_raises_message +from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint from sqlalchemy import create_engine from sqlalchemy import MetaData, Integer, String from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines - +from sqlalchemy.test.testing import AssertsCompiledSQL +from nose import SkipTest class DDLEventTest(TestBase): class Canary(object): @@ -15,25 +16,25 @@ class DDLEventTest(TestBase): self.schema_item = schema_item self.bind = bind - def before_create(self, action, schema_item, bind): + def before_create(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_create(self, action, schema_item, bind): + def after_create(self, action, schema_item, bind, **kw): assert self.state in ('before-create', 'skipped') assert schema_item is self.schema_item assert bind is self.bind self.state = action - def before_drop(self, action, schema_item, bind): + def before_drop(self, action, schema_item, bind, **kw): assert self.state is None assert schema_item is self.schema_item assert bind is self.bind self.state = action - def after_drop(self, action, schema_item, bind): + def after_drop(self, action, schema_item, bind, **kw): assert self.state in ('before-drop', 'skipped') assert schema_item is self.schema_item assert bind is self.bind @@ -232,7 +233,33 @@ class DDLExecutionTest(TestBase): assert 'klptzyxm' not in strings assert 'xyzzy' in strings assert 'fnord' in strings - + + def test_conditional_constraint(self): + metadata, users, engine = self.metadata, self.users, self.engine + nonpg_mock = engines.mock_engine(dialect_name='sqlite') + pg_mock = engines.mock_engine(dialect_name='postgresql') + + constraint = CheckConstraint('a < b',name="my_test_constraint", table=users) + + # by placing the constraint in an Add/Drop construct, + # the 'inline_ddl' flag is set to False + AddConstraint(constraint, on='postgresql').execute_at("after-create", users) + DropConstraint(constraint, on='postgresql').execute_at("before-drop", users) + + metadata.create_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + metadata.drop_all(bind=nonpg_mock) + strings = " ".join(str(x) for x in nonpg_mock.mock) + assert "my_test_constraint" not in strings + + metadata.create_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + metadata.drop_all(bind=pg_mock) + strings = " ".join(str(x) for x in pg_mock.mock) + assert "my_test_constraint" in strings + def test_metadata(self): metadata, engine = self.metadata, self.engine DDL('mxyzptlk').execute_at('before-create', metadata) @@ -255,7 +282,10 @@ class DDLExecutionTest(TestBase): assert 'fnord' in strings def test_ddl_execute(self): - engine = create_engine('sqlite:///') + try: + engine = create_engine('sqlite:///') + except ImportError: + raise SkipTest('Requires sqlite') cx = engine.connect() table = self.users ddl = DDL('SELECT 1') @@ -286,7 +316,7 @@ class DDLExecutionTest(TestBase): r = eval(py) assert list(r) == [(1,)], py -class DDLTest(TestBase): +class DDLTest(TestBase, AssertsCompiledSQL): def mock_engine(self): executor = lambda *a, **kw: None engine = create_engine(testing.db.name + '://', @@ -297,7 +327,6 @@ class DDLTest(TestBase): def test_tokens(self): m = MetaData() - bind = self.mock_engine() sane_alone = Table('t', m, Column('id', Integer)) sane_schema = Table('t', m, Column('id', Integer), schema='s') insane_alone = Table('t t', m, Column('id', Integer)) @@ -305,20 +334,21 @@ class DDLTest(TestBase): ddl = DDL('%(schema)s-%(table)s-%(fullname)s') - eq_(ddl._expand(sane_alone, bind), '-t-t') - eq_(ddl._expand(sane_schema, bind), 's-t-s.t') - eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"') - eq_(ddl._expand(insane_schema, bind), - '"s s"-"t t"-"s s"."t t"') + dialect = self.mock_engine().dialect + self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect) # overrides are used piece-meal and verbatim. ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s', context={'schema':'S S', 'table': 'T T', 'bonus': 'b'}) - eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b') - eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b') - eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b') - eq_(ddl._expand(insane_schema, bind), - 'S S-T T-"s s"."t t"-b') + + self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect) + self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect) + self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect) + self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect) + def test_filter(self): cx = self.mock_engine() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 08bf80fe2..4783c5508 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -15,18 +15,20 @@ class ExecuteTest(TestBase): global users, metadata metadata = MetaData(testing.db) users = Table('users', metadata, - Column('user_id', INT, primary_key = True), + Column('user_id', INT, primary_key = True, test_needs_autoincrement=True), Column('user_name', VARCHAR(20)), ) metadata.create_all() + @engines.close_first def teardown(self): testing.db.connect().execute(users.delete()) + @classmethod def teardown_class(cls): metadata.drop_all() - @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite') + @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc') def test_raw_qmark(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) @@ -38,7 +40,8 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('mysql', 'postgres') + @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql') + @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported') # some psycopg2 versions bomb this. def test_raw_sprintf(self): for conn in (testing.db, testing.db.connect()): @@ -52,8 +55,8 @@ class ExecuteTest(TestBase): # pyformat is supported for mysql, but skipping because a few driver # versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2) - @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky') - @testing.fails_on_everything_except('postgres') + @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky') + @testing.fails_on_everything_except('postgresql+psycopg2') def test_raw_python(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'}) @@ -63,7 +66,7 @@ class ExecuteTest(TestBase): assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')] conn.execute("delete from users") - @testing.fails_on_everything_except('sqlite', 'oracle') + @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle') def test_raw_named(self): for conn in (testing.db, testing.db.connect()): conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'}) @@ -81,11 +84,12 @@ class ExecuteTest(TestBase): except tsa.exc.DBAPIError: assert True - @testing.fails_on('mssql', 'rowcount returns -1') def test_empty_insert(self): """test that execute() interprets [] as a list with no params""" result = testing.db.execute(users.insert().values(user_name=bindparam('name')), []) - eq_(result.rowcount, 1) + eq_(testing.db.execute(users.select()).fetchall(), [ + (1, None) + ]) class ProxyConnectionTest(TestBase): @testing.fails_on('firebird', 'Data type unknown') @@ -102,6 +106,7 @@ class ProxyConnectionTest(TestBase): return execute(clauseelement, *multiparams, **params) def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + print "CE", statement, parameters cursor_stmts.append( (statement, parameters, None) ) @@ -118,8 +123,8 @@ class ProxyConnectionTest(TestBase): break for engine in ( - engines.testing_engine(options=dict(proxy=MyProxy())), - engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal')) + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())), + engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal')) ): m = MetaData(engine) @@ -131,6 +136,7 @@ class ProxyConnectionTest(TestBase): t1.insert().execute(c1=6) assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')] finally: + pass m.drop_all() engine.dispose() @@ -143,14 +149,14 @@ class ProxyConnectionTest(TestBase): ("DROP TABLE t1", {}, None) ] - if engine.dialect.preexecute_pk_sequences: + if True: # or engine.dialect.preexecute_pk_sequences: cursor = [ - ("CREATE TABLE t1", {}, None), + ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']), ("SELECT lower", {'lower_2':'Foo'}, ['Foo']), ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']), - ("select * from t1", {}, None), - ("DROP TABLE t1", {}, None) + ("select * from t1", {}, ()), + ("DROP TABLE t1", {}, ()) ] else: cursor = [ diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py index ca4fbaa48..784a7b9ce 100644 --- a/test/engine/test_metadata.py +++ b/test/engine/test_metadata.py @@ -1,11 +1,11 @@ from sqlalchemy.test.testing import assert_raises, assert_raises_message import pickle -from sqlalchemy import MetaData -from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey +from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column +from sqlalchemy import schema import sqlalchemy as tsa -from sqlalchemy.test import TestBase, ComparesTables, testing, engines +from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines from sqlalchemy.test.testing import eq_ class MetaDataTest(TestBase, ComparesTables): @@ -83,7 +83,7 @@ class MetaDataTest(TestBase, ComparesTables): meta.create_all(testing.db) try: - for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)): + for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)): table_c, table2_c = test() self.assert_tables_equal(table, table_c) self.assert_tables_equal(table2, table2_c) @@ -143,29 +143,30 @@ class MetaDataTest(TestBase, ComparesTables): MetaData(testing.db), autoload=True) -class TableOptionsTest(TestBase): - def setup(self): - self.engine = engines.mock_engine() - self.metadata = MetaData(self.engine) - +class TableOptionsTest(TestBase, AssertsCompiledSQL): def test_prefixes(self): - table1 = Table("temporary_table_1", self.metadata, + table1 = Table("temporary_table_1", MetaData(), Column("col1", Integer), prefixes = ["TEMPORARY"]) - table1.create() - assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)] - del self.engine.mock[:] - table2 = Table("temporary_table_2", self.metadata, + + self.assert_compile( + schema.CreateTable(table1), + "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)" + ) + + table2 = Table("temporary_table_2", MetaData(), Column("col1", Integer), prefixes = ["VIRTUAL"]) - table2.create() - assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)] + self.assert_compile( + schema.CreateTable(table2), + "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)" + ) def test_table_info(self): - - t1 = Table('foo', self.metadata, info={'x':'y'}) - t2 = Table('bar', self.metadata, info={}) - t3 = Table('bat', self.metadata) + metadata = MetaData() + t1 = Table('foo', metadata, info={'x':'y'}) + t2 = Table('bar', metadata, info={}) + t3 = Table('bat', metadata) assert t1.info == {'x':'y'} assert t2.info == {} assert t3.info == {} diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py index 6b7ac37b2..90c0969be 100644 --- a/test/engine/test_parseconnect.py +++ b/test/engine/test_parseconnect.py @@ -1,4 +1,6 @@ -import ConfigParser, StringIO +from sqlalchemy.test.testing import assert_raises, assert_raises_message +import ConfigParser +import StringIO import sqlalchemy.engine.url as url from sqlalchemy import create_engine, engine_from_config import sqlalchemy as tsa @@ -28,8 +30,6 @@ class ParseConnectTest(TestBase): 'dbtype://username:apples%2Foranges@hostspec/mydatabase', ): u = url.make_url(text) - print u, text - print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host assert u.drivername == 'dbtype' assert u.username == 'username' or u.username is None assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None @@ -41,21 +41,28 @@ class CreateEngineTest(TestBase): def test_connect_query(self): dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', + module=dbapi, + _initialize=False + ) c = e.connect() def test_kwargs(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi) + e = create_engine( + 'postgresql://scott:tiger@somehost/test?fooz=somevalue', + connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, + module=dbapi, + _initialize=False + ) c = e.connect() def test_coerce_config(self): raw = r""" [prefixed] -sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue +sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue sqlalchemy.convert_unicode=0 sqlalchemy.echo=false sqlalchemy.echo_pool=1 @@ -65,7 +72,7 @@ sqlalchemy.pool_size=2 sqlalchemy.pool_threadlocal=1 sqlalchemy.pool_timeout=10 [plain] -url=postgres://scott:tiger@somehost/test?fooz=somevalue +url=postgresql://scott:tiger@somehost/test?fooz=somevalue convert_unicode=0 echo=0 echo_pool=1 @@ -79,7 +86,7 @@ pool_timeout=10 ini.readfp(StringIO.StringIO(raw)) expected = { - 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'convert_unicode': 0, 'echo': False, 'echo_pool': True, @@ -97,17 +104,17 @@ pool_timeout=10 self.assert_(tsa.engine._coerce_config(plain, '') == expected) def test_engine_from_config(self): - dbapi = MockDBAPI() + dbapi = mock_dbapi config = { - 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue', + 'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue', 'sqlalchemy.pool_recycle':'50', 'sqlalchemy.echo':'true' } e = engine_from_config(config, module=dbapi) assert e.pool._recycle == 50 - assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue') + assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue') assert e.echo is True def test_custom(self): @@ -116,109 +123,77 @@ pool_timeout=10 def connect(): return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'}) - # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg - e = create_engine('postgres://', creator=connect, module=dbapi) + # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg + e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False) c = e.connect() def test_recycle(self): dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue') - e = create_engine('postgres://', pool_recycle=472, module=dbapi) + e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False) assert e.pool._recycle == 472 def test_badargs(self): - # good arg, use MockDBAPI to prevent oracle import errors - e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) - - try: - e = create_engine("foobar://", module=MockDBAPI()) - assert False - except ImportError: - assert True + assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi) # bad arg - try: - e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi) # bad arg - try: - e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi) - try: - e = create_engine('postgres://', lala=5, module=MockDBAPI()) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi) - try: - e = create_engine('sqlite://', lala=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi) - try: - e = create_engine('mysql://', use_unicode=True, module=MockDBAPI()) - assert False - except TypeError: - assert True - - try: - # sqlite uses SingletonThreadPool which doesnt have max_overflow - e = create_engine('sqlite://', max_overflow=5) - assert False - except TypeError: - assert True + assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi) - e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) + # sqlite uses SingletonThreadPool which doesnt have max_overflow + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5, + module=mock_sqlite_dbapi) - e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) try: - c = e.connect() - assert False - except tsa.exc.DBAPIError: - assert True + e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) + except ImportError: + # no sqlite + pass + else: + # raises DBAPIerror due to use_unicode not a sqlite arg + assert_raises(tsa.exc.DBAPIError, e.connect) def test_urlattr(self): """test the url attribute on ``Engine``.""" - e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI()) + e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False) u = url.make_url('mysql://scott:tiger@localhost/test') - e2 = create_engine(u, module=MockDBAPI()) + e2 = create_engine(u, module=mock_dbapi, _initialize=False) assert e.url.drivername == e2.url.drivername == 'mysql' assert e.url.username == e2.url.username == 'scott' assert e2.url is u def test_poolargs(self): """test that connection pool args make it thru""" - e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI()) + e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False) assert e.pool._recycle == 50 # these args work for QueuePool - e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI()) + e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi) - try: - # but not SingletonThreadPool - e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool) - assert False - except TypeError: - assert True + # but not SingletonThreadPool + assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60, + poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi) class MockDBAPI(object): def __init__(self, **kwargs): self.kwargs = kwargs self.paramstyle = 'named' - def connect(self, **kwargs): - print kwargs, self.kwargs + def connect(self, *args, **kwargs): for k in self.kwargs: assert k in kwargs, "key %s not present in dictionary" % k assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k]) return MockConnection() class MockConnection(object): + def get_server_info(self): + return "5.0" def close(self): pass def cursor(self): @@ -227,4 +202,6 @@ class MockCursor(object): def close(self): pass mock_dbapi = MockDBAPI() - +mock_sqlite_dbapi = msd = MockDBAPI() +msd.version_info = msd.sqlite_version_info = (99, 9, 9) +msd.sqlite_version = '99.9.9' diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index d135ad337..68637281e 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1,7 +1,8 @@ -import threading, time, gc -from sqlalchemy import pool, interfaces +import threading, time +from sqlalchemy import pool, interfaces, create_engine, select import sqlalchemy as tsa -from sqlalchemy.test import TestBase +from sqlalchemy.test import TestBase, testing +from sqlalchemy.test.util import gc_collect, lazy_gc mcid = 1 @@ -51,7 +52,6 @@ class PoolTest(PoolTestBase): connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) self.assert_(connection2 is not connection3) @@ -70,8 +70,6 @@ class PoolTest(PoolTestBase): connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') - print "connection " + repr(connection) - self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) @@ -103,7 +101,8 @@ class PoolTest(PoolTestBase): c2.close() else: c2 = None - + lazy_gc() + if useclose: c1 = p.connect() c2 = p.connect() @@ -117,6 +116,8 @@ class PoolTest(PoolTestBase): # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced if isinstance(p, pool.QueuePool): + lazy_gc() + self.assert_(p.checkedout() == 0) c1 = p.connect() c2 = p.connect() @@ -126,6 +127,7 @@ class PoolTest(PoolTestBase): else: c2 = None c1 = None + lazy_gc() self.assert_(p.checkedout() == 0) def test_properties(self): @@ -164,6 +166,8 @@ class PoolTest(PoolTestBase): def __init__(self): if hasattr(self, 'connect'): self.connect = self.inst_connect + if hasattr(self, 'first_connect'): + self.first_connect = self.inst_first_connect if hasattr(self, 'checkout'): self.checkout = self.inst_checkout if hasattr(self, 'checkin'): @@ -171,14 +175,17 @@ class PoolTest(PoolTestBase): self.clear() def clear(self): self.connected = [] + self.first_connected = [] self.checked_out = [] self.checked_in = [] - def assert_total(innerself, conn, cout, cin): + def assert_total(innerself, conn, fconn, cout, cin): self.assert_(len(innerself.connected) == conn) + self.assert_(len(innerself.first_connected) == fconn) self.assert_(len(innerself.checked_out) == cout) self.assert_(len(innerself.checked_in) == cin) - def assert_in(innerself, item, in_conn, in_cout, in_cin): + def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin): self.assert_((item in innerself.connected) == in_conn) + self.assert_((item in innerself.first_connected) == in_fconn) self.assert_((item in innerself.checked_out) == in_cout) self.assert_((item in innerself.checked_in) == in_cin) def inst_connect(self, con, record): @@ -186,6 +193,11 @@ class PoolTest(PoolTestBase): assert con is not None assert record is not None self.connected.append(con) + def inst_first_connect(self, con, record): + print "first_connect(%s, %s)" % (con, record) + assert con is not None + assert record is not None + self.first_connected.append(con) def inst_checkout(self, con, record, proxy): print "checkout(%s, %s, %s)" % (con, record, proxy) assert con is not None @@ -203,6 +215,9 @@ class PoolTest(PoolTestBase): class ListenConnect(InstrumentingListener): def connect(self, con, record): pass + class ListenFirstConnect(InstrumentingListener): + def first_connect(self, con, record): + pass class ListenCheckOut(InstrumentingListener): def checkout(self, con, record, proxy, num): pass @@ -214,40 +229,43 @@ class PoolTest(PoolTestBase): return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), use_threadlocal=False, **kw) - def assert_listeners(p, total, conn, cout, cin): + def assert_listeners(p, total, conn, fconn, cout, cin): for instance in (p, p.recreate()): self.assert_(len(instance.listeners) == total) self.assert_(len(instance._on_connect) == conn) + self.assert_(len(instance._on_first_connect) == fconn) self.assert_(len(instance._on_checkout) == cout) self.assert_(len(instance._on_checkin) == cin) p = _pool() - assert_listeners(p, 0, 0, 0, 0) + assert_listeners(p, 0, 0, 0, 0, 0) p.add_listener(ListenAll()) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) p.add_listener(ListenConnect()) - assert_listeners(p, 2, 2, 1, 1) + assert_listeners(p, 2, 2, 1, 1, 1) + + p.add_listener(ListenFirstConnect()) + assert_listeners(p, 3, 2, 2, 1, 1) p.add_listener(ListenCheckOut()) - assert_listeners(p, 3, 2, 2, 1) + assert_listeners(p, 4, 2, 2, 2, 1) p.add_listener(ListenCheckIn()) - assert_listeners(p, 4, 2, 2, 2) + assert_listeners(p, 5, 2, 2, 2, 2) del p - print "----" snoop = ListenAll() p = _pool(listeners=[snoop]) - assert_listeners(p, 1, 1, 1, 1) + assert_listeners(p, 1, 1, 1, 1, 1) c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 1, 1, 0) cc = c.connection - snoop.assert_in(cc, True, True, False) + snoop.assert_in(cc, True, True, True, False) c.close() - snoop.assert_in(cc, True, True, True) + snoop.assert_in(cc, True, True, True, True) del c, cc snoop.clear() @@ -255,10 +273,11 @@ class PoolTest(PoolTestBase): # this one depends on immediate gc c = p.connect() cc = c.connection - snoop.assert_in(cc, False, True, False) - snoop.assert_total(0, 1, 0) + snoop.assert_in(cc, False, False, True, False) + snoop.assert_total(0, 0, 1, 0) del c, cc - snoop.assert_total(0, 1, 1) + lazy_gc() + snoop.assert_total(0, 0, 1, 1) p.dispose() snoop.clear() @@ -266,44 +285,46 @@ class PoolTest(PoolTestBase): c = p.connect() c.close() c = p.connect() - snoop.assert_total(1, 2, 1) + snoop.assert_total(1, 0, 2, 1) c.close() - snoop.assert_total(1, 2, 2) + snoop.assert_total(1, 0, 2, 2) # invalidation p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.invalidate() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) c.close() - snoop.assert_total(1, 1, 1) + snoop.assert_total(1, 0, 1, 1) del c - snoop.assert_total(1, 1, 1) + lazy_gc() + snoop.assert_total(1, 0, 1, 1) c = p.connect() - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) c.close() del c - snoop.assert_total(2, 2, 2) + lazy_gc() + snoop.assert_total(2, 0, 2, 2) # detached p.dispose() snoop.clear() c = p.connect() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.detach() - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c.close() del c - snoop.assert_total(1, 1, 0) + snoop.assert_total(1, 0, 1, 0) c = p.connect() - snoop.assert_total(2, 2, 0) + snoop.assert_total(2, 0, 2, 0) c.close() del c - snoop.assert_total(2, 2, 1) + snoop.assert_total(2, 0, 2, 1) def test_listeners_callables(self): dbapi = MockDBAPI() @@ -362,262 +383,293 @@ class PoolTest(PoolTestBase): c.close() assert counts == [1, 2, 3] + def test_listener_after_oninit(self): + """Test that listeners are called after OnInit is removed""" + called = [] + def listener(*args): + called.append(True) + listener.connect = listener + engine = create_engine(testing.db.url) + engine.pool.add_listener(listener) + engine.execute(select([1])) + assert called, "Listener not called on connect" + + class QueuePoolTest(PoolTestBase): - def testqueuepool_del(self): - self._do_testqueuepool(useclose=False) - - def testqueuepool_close(self): - self._do_testqueuepool(useclose=True) - - def _do_testqueuepool(self, useclose=False): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) - - def status(pool): - tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) - print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup - return tup - - c1 = p.connect() - self.assert_(status(p) == (3,0,-2,1)) - c2 = p.connect() - self.assert_(status(p) == (3,0,-1,2)) - c3 = p.connect() - self.assert_(status(p) == (3,0,0,3)) - c4 = p.connect() - self.assert_(status(p) == (3,0,1,4)) - c5 = p.connect() - self.assert_(status(p) == (3,0,2,5)) - c6 = p.connect() - self.assert_(status(p) == (3,0,3,6)) - if useclose: - c4.close() - c3.close() - c2.close() - else: - c4 = c3 = c2 = None - self.assert_(status(p) == (3,3,3,3)) - if useclose: - c1.close() - c5.close() - c6.close() - else: - c1 = c5 = c6 = None - self.assert_(status(p) == (3,3,0,0)) - c1 = p.connect() - c2 = p.connect() - self.assert_(status(p) == (3, 1, 0, 2), status(p)) - if useclose: - c2.close() - else: - c2 = None - self.assert_(status(p) == (3, 2, 0, 1)) - - def test_timeout(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) - c1 = p.connect() - c2 = p.connect() - c3 = p.connect() - now = time.time() - try: - c4 = p.connect() - assert False - except tsa.exc.TimeoutError, e: - assert int(time.time() - now) == 2 - - def test_timeout_race(self): - # test a race condition where the initial connecting threads all race - # to queue.Empty, then block on the mutex. each thread consumes a - # connection as they go in. when the limit is reached, the remaining - # threads go in, and get TimeoutError; even though they never got to - # wait for the timeout on queue.get(). the fix involves checking the - # timeout again within the mutex, and if so, unlocking and throwing - # them back to the start of do_get() - p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) - timeouts = [] - def checkout(): - for x in xrange(1): - now = time.time() - try: - c1 = p.connect() - except tsa.exc.TimeoutError, e: - timeouts.append(int(time.time()) - now) - continue - time.sleep(4) - c1.close() - - threads = [] - for i in xrange(10): - th = threading.Thread(target=checkout) - th.start() - threads.append(th) - for th in threads: - th.join() - - print timeouts - assert len(timeouts) > 0 - for t in timeouts: - assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) - - def _test_overflow(self, thread_count, max_overflow): - def creator(): - time.sleep(.05) - return mock_dbapi.connect() - - p = pool.QueuePool(creator=creator, - pool_size=3, timeout=2, - max_overflow=max_overflow) - peaks = [] - def whammy(): - for i in range(10): - try: - con = p.connect() - time.sleep(.005) - peaks.append(p.overflow()) - con.close() - del con - except tsa.exc.TimeoutError: - pass - threads = [] - for i in xrange(thread_count): - th = threading.Thread(target=whammy) - th.start() - threads.append(th) - for th in threads: - th.join() - - self.assert_(max(peaks) <= max_overflow) - - def test_no_overflow(self): - self._test_overflow(40, 0) - - def test_max_overflow(self): - self._test_overflow(40, 5) - - def test_mixed_close(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = None - assert p.checkedout() == 1 - c1 = None - assert p.checkedout() == 0 - - def test_weakref_kaboom(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - c1.close() - c2 = None - del c1 - del c2 - gc.collect() - assert p.checkedout() == 0 - c3 = p.connect() - assert c3 is not None - - def test_trick_the_counter(self): - """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread - with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an - ambiguous counter. i.e. its not true reference counting.""" - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = p.connect() - c2.close() - self.assert_(p.checkedout() != 0) - - c2.close() - self.assert_(p.checkedout() == 0) - - def test_recycle(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) - - c1 = p.connect() - c_id = id(c1.connection) - c1.close() - c2 = p.connect() - assert id(c2.connection) == c_id - c2.close() - time.sleep(4) - c3= p.connect() - assert id(c3.connection) != c_id - - def test_invalidate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - c1 = p.connect() - assert c1.connection.id == c_id - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_recreate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - p2 = p.recreate() - assert p2.size() == 1 - assert p2._use_threadlocal is False - assert p2._max_overflow == 0 - - def test_reconnect(self): - """tests reconnect operations at the pool level. SA's engine/dialect includes another - layer of reconnect support for 'database was lost' errors.""" + def testqueuepool_del(self): + self._do_testqueuepool(useclose=False) + + def testqueuepool_close(self): + self._do_testqueuepool(useclose=True) + + def _do_testqueuepool(self, useclose=False): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) + + def status(pool): + tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) + print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup + return tup + + c1 = p.connect() + self.assert_(status(p) == (3,0,-2,1)) + c2 = p.connect() + self.assert_(status(p) == (3,0,-1,2)) + c3 = p.connect() + self.assert_(status(p) == (3,0,0,3)) + c4 = p.connect() + self.assert_(status(p) == (3,0,1,4)) + c5 = p.connect() + self.assert_(status(p) == (3,0,2,5)) + c6 = p.connect() + self.assert_(status(p) == (3,0,3,6)) + if useclose: + c4.close() + c3.close() + c2.close() + else: + c4 = c3 = c2 = None + lazy_gc() + + self.assert_(status(p) == (3,3,3,3)) + if useclose: + c1.close() + c5.close() + c6.close() + else: + c1 = c5 = c6 = None + lazy_gc() + + self.assert_(status(p) == (3,3,0,0)) + + c1 = p.connect() + c2 = p.connect() + self.assert_(status(p) == (3, 1, 0, 2), status(p)) + if useclose: + c2.close() + else: + c2 = None + lazy_gc() + + self.assert_(status(p) == (3, 2, 0, 1)) + + c1.close() + + lazy_gc() + assert not pool._refs - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - - c1 = p.connect() - assert c1.connection.id == c_id - dbapi.raise_error = True - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_detach(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - - c1 = p.connect() - c1.detach() - c_id = c1.connection.id - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - - c2.invalidate() - c2 = None - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - - con = c1.connection - - assert not con.closed - c1.close() - assert con.closed - - def test_threadfairy(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c1.close() - c2 = p.connect() - assert c2.connection is not None + def test_timeout(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + now = time.time() + try: + c4 = p.connect() + assert False + except tsa.exc.TimeoutError, e: + assert int(time.time() - now) == 2 + + def test_timeout_race(self): + # test a race condition where the initial connecting threads all race + # to queue.Empty, then block on the mutex. each thread consumes a + # connection as they go in. when the limit is reached, the remaining + # threads go in, and get TimeoutError; even though they never got to + # wait for the timeout on queue.get(). the fix involves checking the + # timeout again within the mutex, and if so, unlocking and throwing + # them back to the start of do_get() + p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) + timeouts = [] + def checkout(): + for x in xrange(1): + now = time.time() + try: + c1 = p.connect() + except tsa.exc.TimeoutError, e: + timeouts.append(int(time.time()) - now) + continue + time.sleep(4) + c1.close() + + threads = [] + for i in xrange(10): + th = threading.Thread(target=checkout) + th.start() + threads.append(th) + for th in threads: + th.join() + + print timeouts + assert len(timeouts) > 0 + for t in timeouts: + assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) + + def _test_overflow(self, thread_count, max_overflow): + def creator(): + time.sleep(.05) + return mock_dbapi.connect() + + p = pool.QueuePool(creator=creator, + pool_size=3, timeout=2, + max_overflow=max_overflow) + peaks = [] + def whammy(): + for i in range(10): + try: + con = p.connect() + time.sleep(.005) + peaks.append(p.overflow()) + con.close() + del con + except tsa.exc.TimeoutError: + pass + threads = [] + for i in xrange(thread_count): + th = threading.Thread(target=whammy) + th.start() + threads.append(th) + for th in threads: + th.join() + + self.assert_(max(peaks) <= max_overflow) + + lazy_gc() + assert not pool._refs + + def test_no_overflow(self): + self._test_overflow(40, 0) + + def test_max_overflow(self): + self._test_overflow(40, 5) + + def test_mixed_close(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = None + assert p.checkedout() == 1 + c1 = None + lazy_gc() + assert p.checkedout() == 0 + + lazy_gc() + assert not pool._refs + + def test_weakref_kaboom(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + c1.close() + c2 = None + del c1 + del c2 + gc_collect() + assert p.checkedout() == 0 + c3 = p.connect() + assert c3 is not None + + def test_trick_the_counter(self): + """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread + with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an + ambiguous counter. i.e. its not true reference counting.""" + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = p.connect() + c2.close() + self.assert_(p.checkedout() != 0) + + c2.close() + self.assert_(p.checkedout() == 0) + + def test_recycle(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) + + c1 = p.connect() + c_id = id(c1.connection) + c1.close() + c2 = p.connect() + assert id(c2.connection) == c_id + c2.close() + time.sleep(4) + c3= p.connect() + assert id(c3.connection) != c_id + + def test_invalidate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + c1 = p.connect() + assert c1.connection.id == c_id + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_recreate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + p2 = p.recreate() + assert p2.size() == 1 + assert p2._use_threadlocal is False + assert p2._max_overflow == 0 + + def test_reconnect(self): + """tests reconnect operations at the pool level. SA's engine/dialect includes another + layer of reconnect support for 'database was lost' errors.""" + + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + + c1 = p.connect() + assert c1.connection.id == c_id + dbapi.raise_error = True + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_detach(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + + c1 = p.connect() + c1.detach() + c_id = c1.connection.id + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + dbapi.raise_error = True + + c2.invalidate() + c2 = None + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + + con = c1.connection + + assert not con.closed + c1.close() + assert con.closed + + def test_threadfairy(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c1.close() + c2 = p.connect() + assert c2.connection is not None class SingletonThreadPoolTest(PoolTestBase): def test_cleanup(self): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 3a525c2a7..6afd71515 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1,12 +1,13 @@ from sqlalchemy.test.testing import eq_ +import time import weakref from sqlalchemy import select, MetaData, Integer, String, pool from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing, engines -import time -import gc +from sqlalchemy.test.util import gc_collect + class MockDisconnect(Exception): pass @@ -54,7 +55,7 @@ class MockReconnectTest(TestBase): dbapi = MockDBAPI() # create engine using our current dburi - db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi) + db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False) # monkeypatch disconnect checker db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect) @@ -98,7 +99,7 @@ class MockReconnectTest(TestBase): assert id(db.pool) != pid # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 conn =db.connect() @@ -118,7 +119,7 @@ class MockReconnectTest(TestBase): pass # assert was invalidated - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 assert not conn.closed assert conn.invalidated @@ -168,7 +169,7 @@ class MockReconnectTest(TestBase): assert conn.invalidated # ensure all connections closed (pool was recycled) - gc.collect() + gc_collect() assert len(dbapi.connections) == 0 # test reconnects @@ -334,7 +335,8 @@ class InvalidateDuringResultTest(TestBase): meta.drop_all() engine.dispose() - @testing.fails_on('mysql', 'FIXME: unknown') + @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close") + @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close") def test_invalidate_on_results(self): conn = engine.connect() @@ -344,7 +346,7 @@ class InvalidateDuringResultTest(TestBase): engine.test_shutdown() try: - result.fetchone() + print "ghost result: %r" % result.fetchone() assert False except tsa.exc.DBAPIError, e: if not e.connection_invalidated: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index ea80776a6..dff9fa1bb 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,17 +1,22 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message import StringIO, unicodedata import sqlalchemy as sa +from sqlalchemy import types as sql_types +from sqlalchemy import schema +from sqlalchemy.engine.reflection import Inspector from sqlalchemy import MetaData from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column import sqlalchemy as tsa from sqlalchemy.test import TestBase, ComparesTables, testing, engines +create_inspector = Inspector.from_engine metadata, users = None, None class ReflectionTest(TestBase, ComparesTables): + @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+') @testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely') def test_basic_reflection(self): meta = MetaData(testing.db) @@ -22,16 +27,16 @@ class ReflectionTest(TestBase, ComparesTables): Column('test1', sa.CHAR(5), nullable=False), Column('test2', sa.Float(5), nullable=False), Column('test3', sa.Text), - Column('test4', sa.Numeric, nullable = False), - Column('test5', sa.DateTime), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.Date), Column('parent_user_id', sa.Integer, sa.ForeignKey('engine_users.user_id')), - Column('test6', sa.DateTime, nullable=False), + Column('test6', sa.Date, nullable=False), Column('test7', sa.Text), Column('test8', sa.Binary), Column('test_passivedefault2', sa.Integer, server_default='5'), Column('test9', sa.Binary(100)), - Column('test_numeric', sa.Numeric()), + Column('test10', sa.Numeric(10, 2)), test_needs_fk=True, ) @@ -52,9 +57,35 @@ class ReflectionTest(TestBase, ComparesTables): self.assert_tables_equal(users, reflected_users) self.assert_tables_equal(addresses, reflected_addresses) finally: - addresses.drop() - users.drop() - + meta.drop_all() + + def test_two_foreign_keys(self): + meta = MetaData(testing.db) + t1 = Table('t1', meta, + Column('id', sa.Integer, primary_key=True), + Column('t2id', sa.Integer, sa.ForeignKey('t2.id')), + Column('t3id', sa.Integer, sa.ForeignKey('t3.id')), + test_needs_fk=True + ) + t2 = Table('t2', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + t3 = Table('t3', meta, + Column('id', sa.Integer, primary_key=True), + test_needs_fk=True + ) + meta.create_all() + try: + meta2 = MetaData() + t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')] + + assert t1r.c.t2id.references(t2r.c.id) + assert t1r.c.t3id.references(t3r.c.id) + + finally: + meta.drop_all() + def test_include_columns(self): meta = MetaData(testing.db) foo = Table('foo', meta, *[Column(n, sa.String(30)) @@ -84,26 +115,68 @@ class ReflectionTest(TestBase, ComparesTables): finally: meta.drop_all() + @testing.emits_warning(r".*omitted columns") + def test_include_columns_indexes(self): + m = MetaData(testing.db) + + t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer)) + sa.Index('foobar', t1.c.a, t1.c.b) + sa.Index('bat', t1.c.a) + m.create_all() + try: + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True) + assert len(t2.indexes) == 2 + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a']) + assert len(t2.indexes) == 1 + + m2 = MetaData(testing.db) + t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b']) + assert len(t2.indexes) == 2 + finally: + m.drop_all() + + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + """ + + meta = MetaData(testing.db) + t1 = Table('test', meta, + Column('id', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + t2 = Table('test2', meta, + Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True), + Column('id2', sa.Integer, primary_key=True), + Column('data', sa.String(50)), + ) + meta.create_all() + try: + m2 = MetaData(testing.db) + t1a = Table('test', m2, autoload=True) + assert t1a._autoincrement_column is t1a.c.id + + t2a = Table('test2', m2, autoload=True) + assert t2a._autoincrement_column is t2a.c.id2 + + finally: + meta.drop_all() + def test_unknown_types(self): meta = MetaData(testing.db) t = Table("test", meta, Column('foo', sa.DateTime)) - import sys - dialect_module = sys.modules[testing.db.dialect.__module__] - - # we're relying on the presence of "ischema_names" in the - # dialect module, else we can't test this. we need to be able - # to get the dialect to not be aware of some type so we temporarily - # monkeypatch. not sure what a better way for this could be, - # except for an established dialect hook or dialect-specific tests - if not hasattr(dialect_module, 'ischema_names'): - return - - ischema_names = dialect_module.ischema_names + ischema_names = testing.db.dialect.ischema_names t.create() - dialect_module.ischema_names = {} + testing.db.dialect.ischema_names = {} try: m2 = MetaData(testing.db) assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True) @@ -115,7 +188,7 @@ class ReflectionTest(TestBase, ComparesTables): assert t3.c.foo.type.__class__ == sa.types.NullType finally: - dialect_module.ischema_names = ischema_names + testing.db.dialect.ischema_names = ischema_names t.drop() def test_basic_override(self): @@ -578,7 +651,6 @@ class ReflectionTest(TestBase, ComparesTables): m9.reflect() self.assert_(not m9.tables) - @testing.fails_on_everything_except('postgres', 'mysql') def test_index_reflection(self): m1 = MetaData(testing.db) t1 = Table('party', m1, @@ -698,7 +770,7 @@ class UnicodeReflectionTest(TestBase): def test_basic(self): try: # the 'convert_unicode' should not get in the way of the reflection - # process. reflecttable for oracle, postgres (others?) expect non-unicode + # process. reflecttable for oracle, postgresql (others?) expect non-unicode # strings in result sets/bind params bind = engines.utf8_engine(options={'convert_unicode':True}) metadata = MetaData(bind) @@ -713,7 +785,8 @@ class UnicodeReflectionTest(TestBase): metadata.create_all() reflected = set(bind.table_names()) - if not names.issubset(reflected): + # Jython 2.5 on Java 5 lacks unicodedata.normalize + if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'): # Python source files in the utf-8 coding seem to normalize # literals as NFC (and the above are explicitly NFC). Maybe # this database normalizes NFD on reflection. @@ -741,23 +814,15 @@ class SchemaTest(TestBase): Column('col1', sa.Integer, primary_key=True), Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')), schema='someschema') - # ensure this doesnt crash - print [t for t in metadata.sorted_tables] - buf = StringIO.StringIO() - def foo(s, p=None): - buf.write(s) - gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo) - gen = gen.dialect.schemagenerator(gen.dialect, gen) - gen.traverse(table1) - gen.traverse(table2) - buf = buf.getvalue() - print buf + + t1 = str(schema.CreateTable(table1).compile(bind=testing.db)) + t2 = str(schema.CreateTable(table2).compile(bind=testing.db)) if testing.db.dialect.preparer(testing.db.dialect).omit_schema: - assert buf.index("CREATE TABLE table1") > -1 - assert buf.index("CREATE TABLE table2") > -1 + assert t1.index("CREATE TABLE table1") > -1 + assert t2.index("CREATE TABLE table2") > -1 else: - assert buf.index("CREATE TABLE someschema.table1") > -1 - assert buf.index("CREATE TABLE someschema.table2") > -1 + assert t1.index("CREATE TABLE someschema.table1") > -1 + assert t2.index("CREATE TABLE someschema.table2") > -1 @testing.crashes('firebird', 'No schema support') @testing.fails_on('sqlite', 'FIXME: unknown') @@ -767,9 +832,9 @@ class SchemaTest(TestBase): def test_explicit_default_schema(self): engine = testing.db - if testing.against('mysql'): + if testing.against('mysql+mysqldb'): schema = testing.db.url.database - elif testing.against('postgres'): + elif testing.against('postgresql'): schema = 'public' elif testing.against('sqlite'): # Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc., @@ -820,4 +885,324 @@ class HasSequenceTest(TestBase): metadata.drop_all(bind=testing.db) eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False) +# Tests related to engine.reflection + +def get_schema(): + if testing.against('oracle'): + return 'scott' + return 'test_schema' + +def createTables(meta, schema=None): + if schema: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('%s.users.user_id' % schema) + ) + else: + parent_user_id = Column('parent_user_id', sa.Integer, + sa.ForeignKey('users.user_id') + ) + + users = Table('users', meta, + Column('user_id', sa.INT, primary_key=True), + Column('user_name', sa.VARCHAR(20), nullable=False), + Column('test1', sa.CHAR(5), nullable=False), + Column('test2', sa.Float(5), nullable=False), + Column('test3', sa.Text), + Column('test4', sa.Numeric(10, 2), nullable = False), + Column('test5', sa.DateTime), + Column('test5-1', sa.TIMESTAMP), + parent_user_id, + Column('test6', sa.DateTime, nullable=False), + Column('test7', sa.Text), + Column('test8', sa.Binary), + Column('test_passivedefault2', sa.Integer, server_default='5'), + Column('test9', sa.Binary(100)), + Column('test10', sa.Numeric(10, 2)), + schema=schema, + test_needs_fk=True, + ) + addresses = Table('email_addresses', meta, + Column('address_id', sa.Integer, primary_key = True), + Column('remote_user_id', sa.Integer, + sa.ForeignKey(users.c.user_id)), + Column('email_address', sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + return (users, addresses) + +def createIndexes(con, schema=None): + fullname = 'users' + if schema: + fullname = "%s.%s" % (schema, 'users') + query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname + con.execute(sa.sql.text(query)) + +def createViews(con, schema=None): + for table_name in ('users', 'email_addresses'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name, + fullname) + con.execute(sa.sql.text(query)) + +def dropViews(con, schema=None): + for table_name in ('email_addresses', 'users'): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + '_v' + query = "DROP VIEW %s" % view_name + con.execute(sa.sql.text(query)) + + +class ComponentReflectionTest(TestBase): + + @testing.requires.schemas + def test_get_schema_names(self): + meta = MetaData(testing.db) + insp = Inspector(meta.bind) + + self.assert_(get_schema() in insp.get_schema_names()) + + def _test_get_table_names(self, schema=None, table_type='table', + order_by=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + try: + insp = Inspector(meta.bind) + if table_type == 'view': + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ['email_addresses_v', 'users_v'] + else: + table_names = insp.get_table_names(schema, + order_by=order_by) + table_names.sort() + if order_by == 'foreign_key': + answer = ['users', 'email_addresses'] + else: + answer = ['email_addresses', 'users'] + eq_(table_names, answer) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_table_names(self): + self._test_get_table_names() + + @testing.requires.schemas + def test_get_table_names_with_schema(self): + self._test_get_table_names(get_schema()) + + def test_get_view_names(self): + self._test_get_table_names(table_type='view') + + @testing.requires.schemas + def test_get_view_names_with_schema(self): + self._test_get_table_names(get_schema(), table_type='view') + + def _test_get_columns(self, schema=None, table_type='table'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + table_names = ['users', 'email_addresses'] + meta.create_all() + if table_type == 'view': + createViews(meta.bind, schema) + table_names = ['users_v', 'email_addresses_v'] + try: + insp = Inspector(meta.bind) + for (table_name, table) in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + # should be in order + for (i, col) in enumerate(table.columns): + eq_(col.name, cols[i]['name']) + ctype = cols[i]['type'].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + if testing.against('oracle') \ + and ctype_def in (sql_types.Date, sql_types.DateTime): + ctype_def = sql_types.Date + + # assert that the desired type and return type + # share a base within one of the generic types. + self.assert_( + len( + set( + ctype.__mro__ + ).intersection(ctype_def.__mro__) + .intersection([sql_types.Integer, sql_types.Numeric, + sql_types.DateTime, sql_types.Date, sql_types.Time, + sql_types.String, sql_types.Binary]) + ) > 0 + ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'], + ctype))) + finally: + if table_type == 'view': + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_columns(self): + self._test_get_columns() + + @testing.requires.schemas + def test_get_columns_with_schema(self): + self._test_get_columns(schema=get_schema()) + + def test_get_view_columns(self): + self._test_get_columns(table_type='view') + + @testing.requires.schemas + def test_get_view_columns_with_schema(self): + self._test_get_columns(schema=get_schema(), table_type='view') + + def _test_get_primary_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + users_pkeys = insp.get_primary_keys(users.name, + schema=schema) + eq_(users_pkeys, ['user_id']) + addr_pkeys = insp.get_primary_keys(addresses.name, + schema=schema) + eq_(addr_pkeys, ['address_id']) + + finally: + addresses.drop() + users.drop() + + def test_get_primary_keys(self): + self._test_get_primary_keys() + + @testing.fails_on('sqlite', 'no schemas') + def test_get_primary_keys_with_schema(self): + self._test_get_primary_keys(schema=get_schema()) + + def _test_get_foreign_keys(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + insp = Inspector(meta.bind) + try: + expected_schema = schema + # users + users_fkeys = insp.get_foreign_keys(users.name, + schema=schema) + fkey1 = users_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['parent_user_id']) + #addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, + schema=schema) + fkey1 = addr_fkeys[0] + self.assert_(fkey1['name'] is not None) + eq_(fkey1['referred_schema'], expected_schema) + eq_(fkey1['referred_table'], users.name) + eq_(fkey1['referred_columns'], ['user_id', ]) + eq_(fkey1['constrained_columns'], ['remote_user_id']) + finally: + addresses.drop() + users.drop() + + def test_get_foreign_keys(self): + self._test_get_foreign_keys() + + @testing.requires.schemas + def test_get_foreign_keys_with_schema(self): + self._test_get_foreign_keys(schema=get_schema()) + + def _test_get_indexes(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createIndexes(meta.bind, schema) + try: + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = Inspector(meta.bind) + indexes = insp.get_indexes('users', schema=schema) + indexes.sort() + expected_indexes = [ + {'unique': False, + 'column_names': ['test1', 'test2'], + 'name': 'users_t_idx'}] + index_names = [d['name'] for d in indexes] + for e_index in expected_indexes: + assert e_index['name'] in index_names + index = indexes[index_names.index(e_index['name'])] + for key in e_index: + eq_(e_index[key], index[key]) + + finally: + addresses.drop() + users.drop() + + def test_get_indexes(self): + self._test_get_indexes() + + @testing.requires.schemas + def test_get_indexes_with_schema(self): + self._test_get_indexes(schema=get_schema()) + + def _test_get_view_definition(self, schema=None): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + createViews(meta.bind, schema) + view_name1 = 'users_v' + view_name2 = 'email_addresses_v' + try: + insp = Inspector(meta.bind) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + finally: + dropViews(meta.bind, schema) + addresses.drop() + users.drop() + + def test_get_view_definition(self): + self._test_get_view_definition() + + @testing.requires.schemas + def test_get_view_definition_with_schema(self): + self._test_get_view_definition(schema=get_schema()) + + def _test_get_table_oid(self, table_name, schema=None): + if testing.against('postgresql'): + meta = MetaData(testing.db) + (users, addresses) = createTables(meta, schema) + meta.create_all() + try: + insp = create_inspector(meta.bind) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, (int, long))) + finally: + addresses.drop() + users.drop() + + def test_get_table_oid(self): + self._test_get_table_oid('users') + + @testing.requires.schemas + def test_get_table_oid_with_schema(self): + self._test_get_table_oid('users', schema=get_schema()) + diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 6698259a4..8e3f3412d 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -20,7 +20,8 @@ class TransactionTest(TestBase): users.create(testing.db) def teardown(self): - testing.db.connect().execute(users.delete()) + testing.db.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(testing.db) @@ -40,6 +41,7 @@ class TransactionTest(TestBase): result = connection.execute("select * from query_users") assert len(result.fetchall()) == 3 transaction.commit() + connection.close() def test_rollback(self): """test a basic rollback""" @@ -176,6 +178,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.savepoints + @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock') def test_nested_subtransaction_commit(self): connection = testing.db.connect() transaction = connection.begin() @@ -274,6 +277,7 @@ class TransactionTest(TestBase): connection.close() @testing.requires.two_phase_transactions + @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail') @testing.fails_on('mysql', 'FIXME: unknown') def test_two_phase_recover(self): # MySQL recovery doesn't currently seem to work correctly @@ -369,7 +373,7 @@ class ExplicitAutoCommitTest(TestBase): Requires PostgreSQL so that we may define a custom function which modifies the database. """ - __only_on__ = 'postgres' + __only_on__ = 'postgresql' @classmethod def setup_class(cls): @@ -380,7 +384,7 @@ class ExplicitAutoCommitTest(TestBase): testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql") def teardown(self): - foo.delete().execute() + foo.delete().execute().close() @classmethod def teardown_class(cls): @@ -453,8 +457,10 @@ class TLTransactionTest(TestBase): test_needs_acid=True, ) users.create(tlengine) + def teardown(self): - tlengine.execute(users.delete()) + tlengine.execute(users.delete()).close() + @classmethod def teardown_class(cls): users.drop(tlengine) @@ -497,6 +503,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + c.close() external_connection.close() def test_rollback(self): @@ -530,7 +537,9 @@ class TLTransactionTest(TestBase): external_connection.close() def test_commits(self): - assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0 + connection = tlengine.connect() + assert connection.execute("select count(1) from query_users").scalar() == 0 + connection.close() connection = tlengine.contextual_connect() transaction = connection.begin() @@ -547,6 +556,7 @@ class TLTransactionTest(TestBase): l = result.fetchall() assert len(l) == 3, "expected 3 got %d" % len(l) transaction.commit() + connection.close() def test_rollback_off_conn(self): # test that a TLTransaction opened off a TLConnection allows that @@ -563,6 +573,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() external_connection.close() def test_morerollback_off_conn(self): @@ -581,6 +592,8 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 0 finally: + conn.close() + conn2.close() external_connection.close() def test_commit_off_connection(self): @@ -596,6 +609,7 @@ class TLTransactionTest(TestBase): try: assert len(result.fetchall()) == 3 finally: + conn.close() external_connection.close() def test_nesting(self): @@ -712,8 +726,10 @@ class ForUpdateTest(TestBase): test_needs_acid=True, ) counters.create(testing.db) + def teardown(self): - testing.db.connect().execute(counters.delete()) + testing.db.execute(counters.delete()).close() + @classmethod def teardown_class(cls): counters.drop(testing.db) @@ -726,7 +742,7 @@ class ForUpdateTest(TestBase): for i in xrange(count): trans = con.begin() try: - existing = con.execute(sel).fetchone() + existing = con.execute(sel).first() incr = existing['counter_value'] + 1 time.sleep(delay) @@ -734,7 +750,7 @@ class ForUpdateTest(TestBase): values={'counter_value':incr})) time.sleep(delay) - readback = con.execute(sel).fetchone() + readback = con.execute(sel).first() if (readback['counter_value'] != incr): raise AssertionError("Got %s post-update, expected %s" % (readback['counter_value'], incr)) @@ -778,7 +794,7 @@ class ForUpdateTest(TestBase): self.assert_(len(errors) == 0) sel = counters.select(whereclause=counters.c.counter_id==1) - final = db.execute(sel).fetchone() + final = db.execute(sel).first() self.assert_(final['counter_value'] == iterations * thread_count) def overlap(self, ids, errors, update_style): |
