summaryrefslogtreecommitdiff
path: root/test/engine
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2009-08-06 21:11:27 +0000
commit8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca (patch)
treeae9e27d12c9fbf8297bb90469509e1cb6a206242 /test/engine
parent7638aa7f242c6ea3d743aa9100e32be2052546a6 (diff)
downloadsqlalchemy-8fc5005dfe3eb66a46470ad8a8c7b95fc4d6bdca.tar.gz
merge 0.6 series to trunk.
Diffstat (limited to 'test/engine')
-rw-r--r--test/engine/test_bind.py4
-rw-r--r--test/engine/test_ddlevents.py72
-rw-r--r--test/engine/test_execute.py34
-rw-r--r--test/engine/test_metadata.py41
-rw-r--r--test/engine/test_parseconnect.py127
-rw-r--r--test/engine/test_pool.py630
-rw-r--r--test/engine/test_reconnect.py18
-rw-r--r--test/engine/test_reflection.py467
-rw-r--r--test/engine/test_transaction.py34
9 files changed, 948 insertions, 479 deletions
diff --git a/test/engine/test_bind.py b/test/engine/test_bind.py
index 7fd3009bc..1122f1632 100644
--- a/test/engine/test_bind.py
+++ b/test/engine/test_bind.py
@@ -121,7 +121,7 @@ class BindTest(testing.TestBase):
table = Table('test_table', metadata,
Column('foo', Integer))
- metadata.connect(bind)
+ metadata.bind = bind
assert metadata.bind is table.bind is bind
metadata.create_all()
@@ -199,7 +199,7 @@ class BindTest(testing.TestBase):
try:
e = elem(bind=bind)
assert e.bind is bind
- e.execute()
+ e.execute().close()
finally:
if isinstance(bind, engine.Connection):
bind.close()
diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py
index 5716006d9..434a5d873 100644
--- a/test/engine/test_ddlevents.py
+++ b/test/engine/test_ddlevents.py
@@ -1,12 +1,13 @@
-from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
-from sqlalchemy.schema import DDL
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Integer, String
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing, engines
-
+from sqlalchemy.test.testing import AssertsCompiledSQL
+from nose import SkipTest
class DDLEventTest(TestBase):
class Canary(object):
@@ -15,25 +16,25 @@ class DDLEventTest(TestBase):
self.schema_item = schema_item
self.bind = bind
- def before_create(self, action, schema_item, bind):
+ def before_create(self, action, schema_item, bind, **kw):
assert self.state is None
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def after_create(self, action, schema_item, bind):
+ def after_create(self, action, schema_item, bind, **kw):
assert self.state in ('before-create', 'skipped')
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def before_drop(self, action, schema_item, bind):
+ def before_drop(self, action, schema_item, bind, **kw):
assert self.state is None
assert schema_item is self.schema_item
assert bind is self.bind
self.state = action
- def after_drop(self, action, schema_item, bind):
+ def after_drop(self, action, schema_item, bind, **kw):
assert self.state in ('before-drop', 'skipped')
assert schema_item is self.schema_item
assert bind is self.bind
@@ -232,7 +233,33 @@ class DDLExecutionTest(TestBase):
assert 'klptzyxm' not in strings
assert 'xyzzy' in strings
assert 'fnord' in strings
-
+
+ def test_conditional_constraint(self):
+ metadata, users, engine = self.metadata, self.users, self.engine
+ nonpg_mock = engines.mock_engine(dialect_name='sqlite')
+ pg_mock = engines.mock_engine(dialect_name='postgresql')
+
+ constraint = CheckConstraint('a < b',name="my_test_constraint", table=users)
+
+ # by placing the constraint in an Add/Drop construct,
+ # the 'inline_ddl' flag is set to False
+ AddConstraint(constraint, on='postgresql').execute_at("after-create", users)
+ DropConstraint(constraint, on='postgresql').execute_at("before-drop", users)
+
+ metadata.create_all(bind=nonpg_mock)
+ strings = " ".join(str(x) for x in nonpg_mock.mock)
+ assert "my_test_constraint" not in strings
+ metadata.drop_all(bind=nonpg_mock)
+ strings = " ".join(str(x) for x in nonpg_mock.mock)
+ assert "my_test_constraint" not in strings
+
+ metadata.create_all(bind=pg_mock)
+ strings = " ".join(str(x) for x in pg_mock.mock)
+ assert "my_test_constraint" in strings
+ metadata.drop_all(bind=pg_mock)
+ strings = " ".join(str(x) for x in pg_mock.mock)
+ assert "my_test_constraint" in strings
+
def test_metadata(self):
metadata, engine = self.metadata, self.engine
DDL('mxyzptlk').execute_at('before-create', metadata)
@@ -255,7 +282,10 @@ class DDLExecutionTest(TestBase):
assert 'fnord' in strings
def test_ddl_execute(self):
- engine = create_engine('sqlite:///')
+ try:
+ engine = create_engine('sqlite:///')
+ except ImportError:
+ raise SkipTest('Requires sqlite')
cx = engine.connect()
table = self.users
ddl = DDL('SELECT 1')
@@ -286,7 +316,7 @@ class DDLExecutionTest(TestBase):
r = eval(py)
assert list(r) == [(1,)], py
-class DDLTest(TestBase):
+class DDLTest(TestBase, AssertsCompiledSQL):
def mock_engine(self):
executor = lambda *a, **kw: None
engine = create_engine(testing.db.name + '://',
@@ -297,7 +327,6 @@ class DDLTest(TestBase):
def test_tokens(self):
m = MetaData()
- bind = self.mock_engine()
sane_alone = Table('t', m, Column('id', Integer))
sane_schema = Table('t', m, Column('id', Integer), schema='s')
insane_alone = Table('t t', m, Column('id', Integer))
@@ -305,20 +334,21 @@ class DDLTest(TestBase):
ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
- eq_(ddl._expand(sane_alone, bind), '-t-t')
- eq_(ddl._expand(sane_schema, bind), 's-t-s.t')
- eq_(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
- eq_(ddl._expand(insane_schema, bind),
- '"s s"-"t t"-"s s"."t t"')
+ dialect = self.mock_engine().dialect
+ self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), '"s s"-"t t"-"s s"."t t"', dialect=dialect)
# overrides are used piece-meal and verbatim.
ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
- eq_(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
- eq_(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
- eq_(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
- eq_(ddl._expand(insane_schema, bind),
- 'S S-T T-"s s"."t t"-b')
+
+ self.assert_compile(ddl.against(sane_alone), 'S S-T T-t-b', dialect=dialect)
+ self.assert_compile(ddl.against(sane_schema), 'S S-T T-s.t-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_alone), 'S S-T T-"t t"-b', dialect=dialect)
+ self.assert_compile(ddl.against(insane_schema), 'S S-T T-"s s"."t t"-b', dialect=dialect)
+
def test_filter(self):
cx = self.mock_engine()
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 08bf80fe2..4783c5508 100644
--- a/test/engine/test_execute.py
+++ b/test/engine/test_execute.py
@@ -15,18 +15,20 @@ class ExecuteTest(TestBase):
global users, metadata
metadata = MetaData(testing.db)
users = Table('users', metadata,
- Column('user_id', INT, primary_key = True),
+ Column('user_id', INT, primary_key = True, test_needs_autoincrement=True),
Column('user_name', VARCHAR(20)),
)
metadata.create_all()
+ @engines.close_first
def teardown(self):
testing.db.connect().execute(users.delete())
+
@classmethod
def teardown_class(cls):
metadata.drop_all()
- @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite')
+ @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc')
def test_raw_qmark(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
@@ -38,7 +40,8 @@ class ExecuteTest(TestBase):
assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
conn.execute("delete from users")
- @testing.fails_on_everything_except('mysql', 'postgres')
+ @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql')
+ @testing.fails_on('postgresql+zxjdbc', 'sprintf not supported')
# some psycopg2 versions bomb this.
def test_raw_sprintf(self):
for conn in (testing.db, testing.db.connect()):
@@ -52,8 +55,8 @@ class ExecuteTest(TestBase):
# pyformat is supported for mysql, but skipping because a few driver
# versions have a bug that bombs out on this test. (1.2.2b3, 1.2.2c1, 1.2.2)
- @testing.skip_if(lambda: testing.against('mysql'), 'db-api flaky')
- @testing.fails_on_everything_except('postgres')
+ @testing.skip_if(lambda: testing.against('mysql+mysqldb'), 'db-api flaky')
+ @testing.fails_on_everything_except('postgresql+psycopg2')
def test_raw_python(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
@@ -63,7 +66,7 @@ class ExecuteTest(TestBase):
assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
conn.execute("delete from users")
- @testing.fails_on_everything_except('sqlite', 'oracle')
+ @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle')
def test_raw_named(self):
for conn in (testing.db, testing.db.connect()):
conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
@@ -81,11 +84,12 @@ class ExecuteTest(TestBase):
except tsa.exc.DBAPIError:
assert True
- @testing.fails_on('mssql', 'rowcount returns -1')
def test_empty_insert(self):
"""test that execute() interprets [] as a list with no params"""
result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
- eq_(result.rowcount, 1)
+ eq_(testing.db.execute(users.select()).fetchall(), [
+ (1, None)
+ ])
class ProxyConnectionTest(TestBase):
@testing.fails_on('firebird', 'Data type unknown')
@@ -102,6 +106,7 @@ class ProxyConnectionTest(TestBase):
return execute(clauseelement, *multiparams, **params)
def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+ print "CE", statement, parameters
cursor_stmts.append(
(statement, parameters, None)
)
@@ -118,8 +123,8 @@ class ProxyConnectionTest(TestBase):
break
for engine in (
- engines.testing_engine(options=dict(proxy=MyProxy())),
- engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+ engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())),
+ engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal'))
):
m = MetaData(engine)
@@ -131,6 +136,7 @@ class ProxyConnectionTest(TestBase):
t1.insert().execute(c1=6)
assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
finally:
+ pass
m.drop_all()
engine.dispose()
@@ -143,14 +149,14 @@ class ProxyConnectionTest(TestBase):
("DROP TABLE t1", {}, None)
]
- if engine.dialect.preexecute_pk_sequences:
+ if True: # or engine.dialect.preexecute_pk_sequences:
cursor = [
- ("CREATE TABLE t1", {}, None),
+ ("CREATE TABLE t1", {}, ()),
("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
- ("select * from t1", {}, None),
- ("DROP TABLE t1", {}, None)
+ ("select * from t1", {}, ()),
+ ("DROP TABLE t1", {}, ())
]
else:
cursor = [
diff --git a/test/engine/test_metadata.py b/test/engine/test_metadata.py
index ca4fbaa48..784a7b9ce 100644
--- a/test/engine/test_metadata.py
+++ b/test/engine/test_metadata.py
@@ -1,11 +1,11 @@
from sqlalchemy.test.testing import assert_raises, assert_raises_message
import pickle
-from sqlalchemy import MetaData
-from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey
+from sqlalchemy import Integer, String, UniqueConstraint, CheckConstraint, ForeignKey, MetaData
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
+from sqlalchemy import schema
import sqlalchemy as tsa
-from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+from sqlalchemy.test import TestBase, ComparesTables, AssertsCompiledSQL, testing, engines
from sqlalchemy.test.testing import eq_
class MetaDataTest(TestBase, ComparesTables):
@@ -83,7 +83,7 @@ class MetaDataTest(TestBase, ComparesTables):
meta.create_all(testing.db)
try:
- for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)):
+ for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)):
table_c, table2_c = test()
self.assert_tables_equal(table, table_c)
self.assert_tables_equal(table2, table2_c)
@@ -143,29 +143,30 @@ class MetaDataTest(TestBase, ComparesTables):
MetaData(testing.db), autoload=True)
-class TableOptionsTest(TestBase):
- def setup(self):
- self.engine = engines.mock_engine()
- self.metadata = MetaData(self.engine)
-
+class TableOptionsTest(TestBase, AssertsCompiledSQL):
def test_prefixes(self):
- table1 = Table("temporary_table_1", self.metadata,
+ table1 = Table("temporary_table_1", MetaData(),
Column("col1", Integer),
prefixes = ["TEMPORARY"])
- table1.create()
- assert [str(x) for x in self.engine.mock if 'CREATE TEMPORARY TABLE' in str(x)]
- del self.engine.mock[:]
- table2 = Table("temporary_table_2", self.metadata,
+
+ self.assert_compile(
+ schema.CreateTable(table1),
+ "CREATE TEMPORARY TABLE temporary_table_1 (col1 INTEGER)"
+ )
+
+ table2 = Table("temporary_table_2", MetaData(),
Column("col1", Integer),
prefixes = ["VIRTUAL"])
- table2.create()
- assert [str(x) for x in self.engine.mock if 'CREATE VIRTUAL TABLE' in str(x)]
+ self.assert_compile(
+ schema.CreateTable(table2),
+ "CREATE VIRTUAL TABLE temporary_table_2 (col1 INTEGER)"
+ )
def test_table_info(self):
-
- t1 = Table('foo', self.metadata, info={'x':'y'})
- t2 = Table('bar', self.metadata, info={})
- t3 = Table('bat', self.metadata)
+ metadata = MetaData()
+ t1 = Table('foo', metadata, info={'x':'y'})
+ t2 = Table('bar', metadata, info={})
+ t3 = Table('bat', metadata)
assert t1.info == {'x':'y'}
assert t2.info == {}
assert t3.info == {}
diff --git a/test/engine/test_parseconnect.py b/test/engine/test_parseconnect.py
index 6b7ac37b2..90c0969be 100644
--- a/test/engine/test_parseconnect.py
+++ b/test/engine/test_parseconnect.py
@@ -1,4 +1,6 @@
-import ConfigParser, StringIO
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+import ConfigParser
+import StringIO
import sqlalchemy.engine.url as url
from sqlalchemy import create_engine, engine_from_config
import sqlalchemy as tsa
@@ -28,8 +30,6 @@ class ParseConnectTest(TestBase):
'dbtype://username:apples%2Foranges@hostspec/mydatabase',
):
u = url.make_url(text)
- print u, text
- print "username=", u.username, "password=", u.password, "database=", u.database, "host=", u.host
assert u.drivername == 'dbtype'
assert u.username == 'username' or u.username is None
assert u.password == 'password' or u.password == 'apples/oranges' or u.password is None
@@ -41,21 +41,28 @@ class CreateEngineTest(TestBase):
def test_connect_query(self):
dbapi = MockDBAPI(foober='12', lala='18', fooz='somevalue')
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue', module=dbapi)
+ e = create_engine(
+ 'postgresql://scott:tiger@somehost/test?foober=12&lala=18&fooz=somevalue',
+ module=dbapi,
+ _initialize=False
+ )
c = e.connect()
def test_kwargs(self):
dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://scott:tiger@somehost/test?fooz=somevalue', connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}}, module=dbapi)
+ e = create_engine(
+ 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
+ connect_args={'foober':12, 'lala':18, 'hoho':{'this':'dict'}},
+ module=dbapi,
+ _initialize=False
+ )
c = e.connect()
def test_coerce_config(self):
raw = r"""
[prefixed]
-sqlalchemy.url=postgres://scott:tiger@somehost/test?fooz=somevalue
+sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue
sqlalchemy.convert_unicode=0
sqlalchemy.echo=false
sqlalchemy.echo_pool=1
@@ -65,7 +72,7 @@ sqlalchemy.pool_size=2
sqlalchemy.pool_threadlocal=1
sqlalchemy.pool_timeout=10
[plain]
-url=postgres://scott:tiger@somehost/test?fooz=somevalue
+url=postgresql://scott:tiger@somehost/test?fooz=somevalue
convert_unicode=0
echo=0
echo_pool=1
@@ -79,7 +86,7 @@ pool_timeout=10
ini.readfp(StringIO.StringIO(raw))
expected = {
- 'url': 'postgres://scott:tiger@somehost/test?fooz=somevalue',
+ 'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
'convert_unicode': 0,
'echo': False,
'echo_pool': True,
@@ -97,17 +104,17 @@ pool_timeout=10
self.assert_(tsa.engine._coerce_config(plain, '') == expected)
def test_engine_from_config(self):
- dbapi = MockDBAPI()
+ dbapi = mock_dbapi
config = {
- 'sqlalchemy.url':'postgres://scott:tiger@somehost/test?fooz=somevalue',
+ 'sqlalchemy.url':'postgresql://scott:tiger@somehost/test?fooz=somevalue',
'sqlalchemy.pool_recycle':'50',
'sqlalchemy.echo':'true'
}
e = engine_from_config(config, module=dbapi)
assert e.pool._recycle == 50
- assert e.url == url.make_url('postgres://scott:tiger@somehost/test?fooz=somevalue')
+ assert e.url == url.make_url('postgresql://scott:tiger@somehost/test?fooz=somevalue')
assert e.echo is True
def test_custom(self):
@@ -116,109 +123,77 @@ pool_timeout=10
def connect():
return dbapi.connect(foober=12, lala=18, fooz='somevalue', hoho={'this':'dict'})
- # start the postgres dialect, but put our mock DBAPI as the module instead of psycopg
- e = create_engine('postgres://', creator=connect, module=dbapi)
+ # start the postgresql dialect, but put our mock DBAPI as the module instead of psycopg
+ e = create_engine('postgresql://', creator=connect, module=dbapi, _initialize=False)
c = e.connect()
def test_recycle(self):
dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
- e = create_engine('postgres://', pool_recycle=472, module=dbapi)
+ e = create_engine('postgresql://', pool_recycle=472, module=dbapi, _initialize=False)
assert e.pool._recycle == 472
def test_badargs(self):
- # good arg, use MockDBAPI to prevent oracle import errors
- e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
-
- try:
- e = create_engine("foobar://", module=MockDBAPI())
- assert False
- except ImportError:
- assert True
+ assert_raises(ImportError, create_engine, "foobar://", module=mock_dbapi)
# bad arg
- try:
- e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'postgresql://', use_ansi=True, module=mock_dbapi)
# bad arg
- try:
- e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=mock_dbapi)
- try:
- e = create_engine('postgres://', lala=5, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'postgresql://', lala=5, module=mock_dbapi)
- try:
- e = create_engine('sqlite://', lala=5)
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine,'sqlite://', lala=5, module=mock_sqlite_dbapi)
- try:
- e = create_engine('mysql://', use_unicode=True, module=MockDBAPI())
- assert False
- except TypeError:
- assert True
-
- try:
- # sqlite uses SingletonThreadPool which doesnt have max_overflow
- e = create_engine('sqlite://', max_overflow=5)
- assert False
- except TypeError:
- assert True
+ assert_raises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=mock_dbapi)
- e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+ # sqlite uses SingletonThreadPool which doesnt have max_overflow
+ assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=5,
+ module=mock_sqlite_dbapi)
- e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
try:
- c = e.connect()
- assert False
- except tsa.exc.DBAPIError:
- assert True
+ e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+ except ImportError:
+ # no sqlite
+ pass
+ else:
+ # raises DBAPIerror due to use_unicode not a sqlite arg
+ assert_raises(tsa.exc.DBAPIError, e.connect)
def test_urlattr(self):
"""test the url attribute on ``Engine``."""
- e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI())
+ e = create_engine('mysql://scott:tiger@localhost/test', module=mock_dbapi, _initialize=False)
u = url.make_url('mysql://scott:tiger@localhost/test')
- e2 = create_engine(u, module=MockDBAPI())
+ e2 = create_engine(u, module=mock_dbapi, _initialize=False)
assert e.url.drivername == e2.url.drivername == 'mysql'
assert e.url.username == e2.url.username == 'scott'
assert e2.url is u
def test_poolargs(self):
"""test that connection pool args make it thru"""
- e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI())
+ e = create_engine('postgresql://', creator=None, pool_recycle=50, echo_pool=None, module=mock_dbapi, _initialize=False)
assert e.pool._recycle == 50
# these args work for QueuePool
- e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
+ e = create_engine('postgresql://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=mock_dbapi)
- try:
- # but not SingletonThreadPool
- e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
- assert False
- except TypeError:
- assert True
+ # but not SingletonThreadPool
+ assert_raises(TypeError, create_engine, 'sqlite://', max_overflow=8, pool_timeout=60,
+ poolclass=tsa.pool.SingletonThreadPool, module=mock_sqlite_dbapi)
class MockDBAPI(object):
def __init__(self, **kwargs):
self.kwargs = kwargs
self.paramstyle = 'named'
- def connect(self, **kwargs):
- print kwargs, self.kwargs
+ def connect(self, *args, **kwargs):
for k in self.kwargs:
assert k in kwargs, "key %s not present in dictionary" % k
assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k])
return MockConnection()
class MockConnection(object):
+ def get_server_info(self):
+ return "5.0"
def close(self):
pass
def cursor(self):
@@ -227,4 +202,6 @@ class MockCursor(object):
def close(self):
pass
mock_dbapi = MockDBAPI()
-
+mock_sqlite_dbapi = msd = MockDBAPI()
+msd.version_info = msd.sqlite_version_info = (99, 9, 9)
+msd.sqlite_version = '99.9.9'
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index d135ad337..68637281e 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -1,7 +1,8 @@
-import threading, time, gc
-from sqlalchemy import pool, interfaces
+import threading, time
+from sqlalchemy import pool, interfaces, create_engine, select
import sqlalchemy as tsa
-from sqlalchemy.test import TestBase
+from sqlalchemy.test import TestBase, testing
+from sqlalchemy.test.util import gc_collect, lazy_gc
mcid = 1
@@ -51,7 +52,6 @@ class PoolTest(PoolTestBase):
connection2 = manager.connect('foo.db')
connection3 = manager.connect('bar.db')
- print "connection " + repr(connection)
self.assert_(connection.cursor() is not None)
self.assert_(connection is connection2)
self.assert_(connection2 is not connection3)
@@ -70,8 +70,6 @@ class PoolTest(PoolTestBase):
connection = manager.connect('foo.db')
connection2 = manager.connect('foo.db')
- print "connection " + repr(connection)
-
self.assert_(connection.cursor() is not None)
self.assert_(connection is not connection2)
@@ -103,7 +101,8 @@ class PoolTest(PoolTestBase):
c2.close()
else:
c2 = None
-
+ lazy_gc()
+
if useclose:
c1 = p.connect()
c2 = p.connect()
@@ -117,6 +116,8 @@ class PoolTest(PoolTestBase):
# extra tests with QueuePool to ensure connections get __del__()ed when dereferenced
if isinstance(p, pool.QueuePool):
+ lazy_gc()
+
self.assert_(p.checkedout() == 0)
c1 = p.connect()
c2 = p.connect()
@@ -126,6 +127,7 @@ class PoolTest(PoolTestBase):
else:
c2 = None
c1 = None
+ lazy_gc()
self.assert_(p.checkedout() == 0)
def test_properties(self):
@@ -164,6 +166,8 @@ class PoolTest(PoolTestBase):
def __init__(self):
if hasattr(self, 'connect'):
self.connect = self.inst_connect
+ if hasattr(self, 'first_connect'):
+ self.first_connect = self.inst_first_connect
if hasattr(self, 'checkout'):
self.checkout = self.inst_checkout
if hasattr(self, 'checkin'):
@@ -171,14 +175,17 @@ class PoolTest(PoolTestBase):
self.clear()
def clear(self):
self.connected = []
+ self.first_connected = []
self.checked_out = []
self.checked_in = []
- def assert_total(innerself, conn, cout, cin):
+ def assert_total(innerself, conn, fconn, cout, cin):
self.assert_(len(innerself.connected) == conn)
+ self.assert_(len(innerself.first_connected) == fconn)
self.assert_(len(innerself.checked_out) == cout)
self.assert_(len(innerself.checked_in) == cin)
- def assert_in(innerself, item, in_conn, in_cout, in_cin):
+ def assert_in(innerself, item, in_conn, in_fconn, in_cout, in_cin):
self.assert_((item in innerself.connected) == in_conn)
+ self.assert_((item in innerself.first_connected) == in_fconn)
self.assert_((item in innerself.checked_out) == in_cout)
self.assert_((item in innerself.checked_in) == in_cin)
def inst_connect(self, con, record):
@@ -186,6 +193,11 @@ class PoolTest(PoolTestBase):
assert con is not None
assert record is not None
self.connected.append(con)
+ def inst_first_connect(self, con, record):
+ print "first_connect(%s, %s)" % (con, record)
+ assert con is not None
+ assert record is not None
+ self.first_connected.append(con)
def inst_checkout(self, con, record, proxy):
print "checkout(%s, %s, %s)" % (con, record, proxy)
assert con is not None
@@ -203,6 +215,9 @@ class PoolTest(PoolTestBase):
class ListenConnect(InstrumentingListener):
def connect(self, con, record):
pass
+ class ListenFirstConnect(InstrumentingListener):
+ def first_connect(self, con, record):
+ pass
class ListenCheckOut(InstrumentingListener):
def checkout(self, con, record, proxy, num):
pass
@@ -214,40 +229,43 @@ class PoolTest(PoolTestBase):
return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
use_threadlocal=False, **kw)
- def assert_listeners(p, total, conn, cout, cin):
+ def assert_listeners(p, total, conn, fconn, cout, cin):
for instance in (p, p.recreate()):
self.assert_(len(instance.listeners) == total)
self.assert_(len(instance._on_connect) == conn)
+ self.assert_(len(instance._on_first_connect) == fconn)
self.assert_(len(instance._on_checkout) == cout)
self.assert_(len(instance._on_checkin) == cin)
p = _pool()
- assert_listeners(p, 0, 0, 0, 0)
+ assert_listeners(p, 0, 0, 0, 0, 0)
p.add_listener(ListenAll())
- assert_listeners(p, 1, 1, 1, 1)
+ assert_listeners(p, 1, 1, 1, 1, 1)
p.add_listener(ListenConnect())
- assert_listeners(p, 2, 2, 1, 1)
+ assert_listeners(p, 2, 2, 1, 1, 1)
+
+ p.add_listener(ListenFirstConnect())
+ assert_listeners(p, 3, 2, 2, 1, 1)
p.add_listener(ListenCheckOut())
- assert_listeners(p, 3, 2, 2, 1)
+ assert_listeners(p, 4, 2, 2, 2, 1)
p.add_listener(ListenCheckIn())
- assert_listeners(p, 4, 2, 2, 2)
+ assert_listeners(p, 5, 2, 2, 2, 2)
del p
- print "----"
snoop = ListenAll()
p = _pool(listeners=[snoop])
- assert_listeners(p, 1, 1, 1, 1)
+ assert_listeners(p, 1, 1, 1, 1, 1)
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 1, 1, 0)
cc = c.connection
- snoop.assert_in(cc, True, True, False)
+ snoop.assert_in(cc, True, True, True, False)
c.close()
- snoop.assert_in(cc, True, True, True)
+ snoop.assert_in(cc, True, True, True, True)
del c, cc
snoop.clear()
@@ -255,10 +273,11 @@ class PoolTest(PoolTestBase):
# this one depends on immediate gc
c = p.connect()
cc = c.connection
- snoop.assert_in(cc, False, True, False)
- snoop.assert_total(0, 1, 0)
+ snoop.assert_in(cc, False, False, True, False)
+ snoop.assert_total(0, 0, 1, 0)
del c, cc
- snoop.assert_total(0, 1, 1)
+ lazy_gc()
+ snoop.assert_total(0, 0, 1, 1)
p.dispose()
snoop.clear()
@@ -266,44 +285,46 @@ class PoolTest(PoolTestBase):
c = p.connect()
c.close()
c = p.connect()
- snoop.assert_total(1, 2, 1)
+ snoop.assert_total(1, 0, 2, 1)
c.close()
- snoop.assert_total(1, 2, 2)
+ snoop.assert_total(1, 0, 2, 2)
# invalidation
p.dispose()
snoop.clear()
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.invalidate()
- snoop.assert_total(1, 1, 1)
+ snoop.assert_total(1, 0, 1, 1)
c.close()
- snoop.assert_total(1, 1, 1)
+ snoop.assert_total(1, 0, 1, 1)
del c
- snoop.assert_total(1, 1, 1)
+ lazy_gc()
+ snoop.assert_total(1, 0, 1, 1)
c = p.connect()
- snoop.assert_total(2, 2, 1)
+ snoop.assert_total(2, 0, 2, 1)
c.close()
del c
- snoop.assert_total(2, 2, 2)
+ lazy_gc()
+ snoop.assert_total(2, 0, 2, 2)
# detached
p.dispose()
snoop.clear()
c = p.connect()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.detach()
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c.close()
del c
- snoop.assert_total(1, 1, 0)
+ snoop.assert_total(1, 0, 1, 0)
c = p.connect()
- snoop.assert_total(2, 2, 0)
+ snoop.assert_total(2, 0, 2, 0)
c.close()
del c
- snoop.assert_total(2, 2, 1)
+ snoop.assert_total(2, 0, 2, 1)
def test_listeners_callables(self):
dbapi = MockDBAPI()
@@ -362,262 +383,293 @@ class PoolTest(PoolTestBase):
c.close()
assert counts == [1, 2, 3]
+ def test_listener_after_oninit(self):
+ """Test that listeners are called after OnInit is removed"""
+ called = []
+ def listener(*args):
+ called.append(True)
+ listener.connect = listener
+ engine = create_engine(testing.db.url)
+ engine.pool.add_listener(listener)
+ engine.execute(select([1]))
+ assert called, "Listener not called on connect"
+
+
class QueuePoolTest(PoolTestBase):
- def testqueuepool_del(self):
- self._do_testqueuepool(useclose=False)
-
- def testqueuepool_close(self):
- self._do_testqueuepool(useclose=True)
-
- def _do_testqueuepool(self, useclose=False):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
-
- def status(pool):
- tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
- print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
- return tup
-
- c1 = p.connect()
- self.assert_(status(p) == (3,0,-2,1))
- c2 = p.connect()
- self.assert_(status(p) == (3,0,-1,2))
- c3 = p.connect()
- self.assert_(status(p) == (3,0,0,3))
- c4 = p.connect()
- self.assert_(status(p) == (3,0,1,4))
- c5 = p.connect()
- self.assert_(status(p) == (3,0,2,5))
- c6 = p.connect()
- self.assert_(status(p) == (3,0,3,6))
- if useclose:
- c4.close()
- c3.close()
- c2.close()
- else:
- c4 = c3 = c2 = None
- self.assert_(status(p) == (3,3,3,3))
- if useclose:
- c1.close()
- c5.close()
- c6.close()
- else:
- c1 = c5 = c6 = None
- self.assert_(status(p) == (3,3,0,0))
- c1 = p.connect()
- c2 = p.connect()
- self.assert_(status(p) == (3, 1, 0, 2), status(p))
- if useclose:
- c2.close()
- else:
- c2 = None
- self.assert_(status(p) == (3, 2, 0, 1))
-
- def test_timeout(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
- c1 = p.connect()
- c2 = p.connect()
- c3 = p.connect()
- now = time.time()
- try:
- c4 = p.connect()
- assert False
- except tsa.exc.TimeoutError, e:
- assert int(time.time() - now) == 2
-
- def test_timeout_race(self):
- # test a race condition where the initial connecting threads all race
- # to queue.Empty, then block on the mutex. each thread consumes a
- # connection as they go in. when the limit is reached, the remaining
- # threads go in, and get TimeoutError; even though they never got to
- # wait for the timeout on queue.get(). the fix involves checking the
- # timeout again within the mutex, and if so, unlocking and throwing
- # them back to the start of do_get()
- p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
- timeouts = []
- def checkout():
- for x in xrange(1):
- now = time.time()
- try:
- c1 = p.connect()
- except tsa.exc.TimeoutError, e:
- timeouts.append(int(time.time()) - now)
- continue
- time.sleep(4)
- c1.close()
-
- threads = []
- for i in xrange(10):
- th = threading.Thread(target=checkout)
- th.start()
- threads.append(th)
- for th in threads:
- th.join()
-
- print timeouts
- assert len(timeouts) > 0
- for t in timeouts:
- assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
-
- def _test_overflow(self, thread_count, max_overflow):
- def creator():
- time.sleep(.05)
- return mock_dbapi.connect()
-
- p = pool.QueuePool(creator=creator,
- pool_size=3, timeout=2,
- max_overflow=max_overflow)
- peaks = []
- def whammy():
- for i in range(10):
- try:
- con = p.connect()
- time.sleep(.005)
- peaks.append(p.overflow())
- con.close()
- del con
- except tsa.exc.TimeoutError:
- pass
- threads = []
- for i in xrange(thread_count):
- th = threading.Thread(target=whammy)
- th.start()
- threads.append(th)
- for th in threads:
- th.join()
-
- self.assert_(max(peaks) <= max_overflow)
-
- def test_no_overflow(self):
- self._test_overflow(40, 0)
-
- def test_max_overflow(self):
- self._test_overflow(40, 5)
-
- def test_mixed_close(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- assert c1 is c2
- c1.close()
- c2 = None
- assert p.checkedout() == 1
- c1 = None
- assert p.checkedout() == 0
-
- def test_weakref_kaboom(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- c1.close()
- c2 = None
- del c1
- del c2
- gc.collect()
- assert p.checkedout() == 0
- c3 = p.connect()
- assert c3 is not None
-
- def test_trick_the_counter(self):
- """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
- with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
- ambiguous counter. i.e. its not true reference counting."""
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c2 = p.connect()
- assert c1 is c2
- c1.close()
- c2 = p.connect()
- c2.close()
- self.assert_(p.checkedout() != 0)
-
- c2.close()
- self.assert_(p.checkedout() == 0)
-
- def test_recycle(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
-
- c1 = p.connect()
- c_id = id(c1.connection)
- c1.close()
- c2 = p.connect()
- assert id(c2.connection) == c_id
- c2.close()
- time.sleep(4)
- c3= p.connect()
- assert id(c3.connection) != c_id
-
- def test_invalidate(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- c1 = p.connect()
- c_id = c1.connection.id
- c1.close(); c1=None
- c1 = p.connect()
- assert c1.connection.id == c_id
- c1.invalidate()
- c1 = None
-
- c1 = p.connect()
- assert c1.connection.id != c_id
-
- def test_recreate(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- p2 = p.recreate()
- assert p2.size() == 1
- assert p2._use_threadlocal is False
- assert p2._max_overflow == 0
-
- def test_reconnect(self):
- """tests reconnect operations at the pool level. SA's engine/dialect includes another
- layer of reconnect support for 'database was lost' errors."""
+ def testqueuepool_del(self):
+ self._do_testqueuepool(useclose=False)
+
+ def testqueuepool_close(self):
+ self._do_testqueuepool(useclose=True)
+
+ def _do_testqueuepool(self, useclose=False):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False)
+
+ def status(pool):
+ tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout())
+ print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
+ return tup
+
+ c1 = p.connect()
+ self.assert_(status(p) == (3,0,-2,1))
+ c2 = p.connect()
+ self.assert_(status(p) == (3,0,-1,2))
+ c3 = p.connect()
+ self.assert_(status(p) == (3,0,0,3))
+ c4 = p.connect()
+ self.assert_(status(p) == (3,0,1,4))
+ c5 = p.connect()
+ self.assert_(status(p) == (3,0,2,5))
+ c6 = p.connect()
+ self.assert_(status(p) == (3,0,3,6))
+ if useclose:
+ c4.close()
+ c3.close()
+ c2.close()
+ else:
+ c4 = c3 = c2 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3,3,3,3))
+ if useclose:
+ c1.close()
+ c5.close()
+ c6.close()
+ else:
+ c1 = c5 = c6 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3,3,0,0))
+
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(status(p) == (3, 1, 0, 2), status(p))
+ if useclose:
+ c2.close()
+ else:
+ c2 = None
+ lazy_gc()
+
+ self.assert_(status(p) == (3, 2, 0, 1))
+
+ c1.close()
+
+ lazy_gc()
+ assert not pool._refs
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
- c1 = p.connect()
- c_id = c1.connection.id
- c1.close(); c1=None
-
- c1 = p.connect()
- assert c1.connection.id == c_id
- dbapi.raise_error = True
- c1.invalidate()
- c1 = None
-
- c1 = p.connect()
- assert c1.connection.id != c_id
-
- def test_detach(self):
- dbapi = MockDBAPI()
- p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
-
- c1 = p.connect()
- c1.detach()
- c_id = c1.connection.id
-
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
- dbapi.raise_error = True
-
- c2.invalidate()
- c2 = None
-
- c2 = p.connect()
- assert c2.connection.id != c1.connection.id
-
- con = c1.connection
-
- assert not con.closed
- c1.close()
- assert con.closed
-
- def test_threadfairy(self):
- p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
- c1 = p.connect()
- c1.close()
- c2 = p.connect()
- assert c2.connection is not None
+ def test_timeout(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2)
+ c1 = p.connect()
+ c2 = p.connect()
+ c3 = p.connect()
+ now = time.time()
+ try:
+ c4 = p.connect()
+ assert False
+ except tsa.exc.TimeoutError, e:
+ assert int(time.time() - now) == 2
+
+ def test_timeout_race(self):
+ # test a race condition where the initial connecting threads all race
+ # to queue.Empty, then block on the mutex. each thread consumes a
+ # connection as they go in. when the limit is reached, the remaining
+ # threads go in, and get TimeoutError; even though they never got to
+ # wait for the timeout on queue.get(). the fix involves checking the
+ # timeout again within the mutex, and if so, unlocking and throwing
+ # them back to the start of do_get()
+ p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3)
+ timeouts = []
+ def checkout():
+ for x in xrange(1):
+ now = time.time()
+ try:
+ c1 = p.connect()
+ except tsa.exc.TimeoutError, e:
+ timeouts.append(int(time.time()) - now)
+ continue
+ time.sleep(4)
+ c1.close()
+
+ threads = []
+ for i in xrange(10):
+ th = threading.Thread(target=checkout)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ print timeouts
+ assert len(timeouts) > 0
+ for t in timeouts:
+ assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts)
+
+ def _test_overflow(self, thread_count, max_overflow):
+ def creator():
+ time.sleep(.05)
+ return mock_dbapi.connect()
+
+ p = pool.QueuePool(creator=creator,
+ pool_size=3, timeout=2,
+ max_overflow=max_overflow)
+ peaks = []
+ def whammy():
+ for i in range(10):
+ try:
+ con = p.connect()
+ time.sleep(.005)
+ peaks.append(p.overflow())
+ con.close()
+ del con
+ except tsa.exc.TimeoutError:
+ pass
+ threads = []
+ for i in xrange(thread_count):
+ th = threading.Thread(target=whammy)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ self.assert_(max(peaks) <= max_overflow)
+
+ lazy_gc()
+ assert not pool._refs
+
+ def test_no_overflow(self):
+ self._test_overflow(40, 0)
+
+ def test_max_overflow(self):
+ self._test_overflow(40, 5)
+
+ def test_mixed_close(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = None
+ assert p.checkedout() == 1
+ c1 = None
+ lazy_gc()
+ assert p.checkedout() == 0
+
+ lazy_gc()
+ assert not pool._refs
+
+ def test_weakref_kaboom(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ c1.close()
+ c2 = None
+ del c1
+ del c2
+ gc_collect()
+ assert p.checkedout() == 0
+ c3 = p.connect()
+ assert c3 is not None
+
+ def test_trick_the_counter(self):
+ """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread
+ with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an
+ ambiguous counter. i.e. its not true reference counting."""
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c2 = p.connect()
+ assert c1 is c2
+ c1.close()
+ c2 = p.connect()
+ c2.close()
+ self.assert_(p.checkedout() != 0)
+
+ c2.close()
+ self.assert_(p.checkedout() == 0)
+
+ def test_recycle(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3)
+
+ c1 = p.connect()
+ c_id = id(c1.connection)
+ c1.close()
+ c2 = p.connect()
+ assert id(c2.connection) == c_id
+ c2.close()
+ time.sleep(4)
+ c3= p.connect()
+ assert id(c3.connection) != c_id
+
+ def test_invalidate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_recreate(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ p2 = p.recreate()
+ assert p2.size() == 1
+ assert p2._use_threadlocal is False
+ assert p2._max_overflow == 0
+
+ def test_reconnect(self):
+ """tests reconnect operations at the pool level. SA's engine/dialect includes another
+ layer of reconnect support for 'database was lost' errors."""
+
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+ c1 = p.connect()
+ c_id = c1.connection.id
+ c1.close(); c1=None
+
+ c1 = p.connect()
+ assert c1.connection.id == c_id
+ dbapi.raise_error = True
+ c1.invalidate()
+ c1 = None
+
+ c1 = p.connect()
+ assert c1.connection.id != c_id
+
+ def test_detach(self):
+ dbapi = MockDBAPI()
+ p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False)
+
+ c1 = p.connect()
+ c1.detach()
+ c_id = c1.connection.id
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+ dbapi.raise_error = True
+
+ c2.invalidate()
+ c2 = None
+
+ c2 = p.connect()
+ assert c2.connection.id != c1.connection.id
+
+ con = c1.connection
+
+ assert not con.closed
+ c1.close()
+ assert con.closed
+
+ def test_threadfairy(self):
+ p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True)
+ c1 = p.connect()
+ c1.close()
+ c2 = p.connect()
+ assert c2.connection is not None
class SingletonThreadPoolTest(PoolTestBase):
def test_cleanup(self):
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 3a525c2a7..6afd71515 100644
--- a/test/engine/test_reconnect.py
+++ b/test/engine/test_reconnect.py
@@ -1,12 +1,13 @@
from sqlalchemy.test.testing import eq_
+import time
import weakref
from sqlalchemy import select, MetaData, Integer, String, pool
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing, engines
-import time
-import gc
+from sqlalchemy.test.util import gc_collect
+
class MockDisconnect(Exception):
pass
@@ -54,7 +55,7 @@ class MockReconnectTest(TestBase):
dbapi = MockDBAPI()
# create engine using our current dburi
- db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+ db = tsa.create_engine('postgresql://foo:bar@localhost/test', module=dbapi, _initialize=False)
# monkeypatch disconnect checker
db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
@@ -98,7 +99,7 @@ class MockReconnectTest(TestBase):
assert id(db.pool) != pid
# ensure all connections closed (pool was recycled)
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
conn =db.connect()
@@ -118,7 +119,7 @@ class MockReconnectTest(TestBase):
pass
# assert was invalidated
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
assert not conn.closed
assert conn.invalidated
@@ -168,7 +169,7 @@ class MockReconnectTest(TestBase):
assert conn.invalidated
# ensure all connections closed (pool was recycled)
- gc.collect()
+ gc_collect()
assert len(dbapi.connections) == 0
# test reconnects
@@ -334,7 +335,8 @@ class InvalidateDuringResultTest(TestBase):
meta.drop_all()
engine.dispose()
- @testing.fails_on('mysql', 'FIXME: unknown')
+ @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close")
+ @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close")
def test_invalidate_on_results(self):
conn = engine.connect()
@@ -344,7 +346,7 @@ class InvalidateDuringResultTest(TestBase):
engine.test_shutdown()
try:
- result.fetchone()
+ print "ghost result: %r" % result.fetchone()
assert False
except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py
index ea80776a6..dff9fa1bb 100644
--- a/test/engine/test_reflection.py
+++ b/test/engine/test_reflection.py
@@ -1,17 +1,22 @@
from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
import StringIO, unicodedata
import sqlalchemy as sa
+from sqlalchemy import types as sql_types
+from sqlalchemy import schema
+from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import MetaData
from sqlalchemy.test.schema import Table
from sqlalchemy.test.schema import Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, ComparesTables, testing, engines
+create_inspector = Inspector.from_engine
metadata, users = None, None
class ReflectionTest(TestBase, ComparesTables):
+ @testing.exclude('mssql', '<', (10, 0, 0), 'Date is only supported on MSSQL 2008+')
@testing.exclude('mysql', '<', (4, 1, 1), 'early types are squirrely')
def test_basic_reflection(self):
meta = MetaData(testing.db)
@@ -22,16 +27,16 @@ class ReflectionTest(TestBase, ComparesTables):
Column('test1', sa.CHAR(5), nullable=False),
Column('test2', sa.Float(5), nullable=False),
Column('test3', sa.Text),
- Column('test4', sa.Numeric, nullable = False),
- Column('test5', sa.DateTime),
+ Column('test4', sa.Numeric(10, 2), nullable = False),
+ Column('test5', sa.Date),
Column('parent_user_id', sa.Integer,
sa.ForeignKey('engine_users.user_id')),
- Column('test6', sa.DateTime, nullable=False),
+ Column('test6', sa.Date, nullable=False),
Column('test7', sa.Text),
Column('test8', sa.Binary),
Column('test_passivedefault2', sa.Integer, server_default='5'),
Column('test9', sa.Binary(100)),
- Column('test_numeric', sa.Numeric()),
+ Column('test10', sa.Numeric(10, 2)),
test_needs_fk=True,
)
@@ -52,9 +57,35 @@ class ReflectionTest(TestBase, ComparesTables):
self.assert_tables_equal(users, reflected_users)
self.assert_tables_equal(addresses, reflected_addresses)
finally:
- addresses.drop()
- users.drop()
-
+ meta.drop_all()
+
+ def test_two_foreign_keys(self):
+ meta = MetaData(testing.db)
+ t1 = Table('t1', meta,
+ Column('id', sa.Integer, primary_key=True),
+ Column('t2id', sa.Integer, sa.ForeignKey('t2.id')),
+ Column('t3id', sa.Integer, sa.ForeignKey('t3.id')),
+ test_needs_fk=True
+ )
+ t2 = Table('t2', meta,
+ Column('id', sa.Integer, primary_key=True),
+ test_needs_fk=True
+ )
+ t3 = Table('t3', meta,
+ Column('id', sa.Integer, primary_key=True),
+ test_needs_fk=True
+ )
+ meta.create_all()
+ try:
+ meta2 = MetaData()
+ t1r, t2r, t3r = [Table(x, meta2, autoload=True, autoload_with=testing.db) for x in ('t1', 't2', 't3')]
+
+ assert t1r.c.t2id.references(t2r.c.id)
+ assert t1r.c.t3id.references(t3r.c.id)
+
+ finally:
+ meta.drop_all()
+
def test_include_columns(self):
meta = MetaData(testing.db)
foo = Table('foo', meta, *[Column(n, sa.String(30))
@@ -84,26 +115,68 @@ class ReflectionTest(TestBase, ComparesTables):
finally:
meta.drop_all()
+ @testing.emits_warning(r".*omitted columns")
+ def test_include_columns_indexes(self):
+ m = MetaData(testing.db)
+
+ t1 = Table('t1', m, Column('a', sa.Integer), Column('b', sa.Integer))
+ sa.Index('foobar', t1.c.a, t1.c.b)
+ sa.Index('bat', t1.c.a)
+ m.create_all()
+ try:
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True)
+ assert len(t2.indexes) == 2
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True, include_columns=['a'])
+ assert len(t2.indexes) == 1
+
+ m2 = MetaData(testing.db)
+ t2 = Table('t1', m2, autoload=True, include_columns=['a', 'b'])
+ assert len(t2.indexes) == 2
+ finally:
+ m.drop_all()
+
+ def test_autoincrement_col(self):
+ """test that 'autoincrement' is reflected according to sqla's policy.
+
+ Don't mark this test as unsupported for any backend !
+
+ (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+ """
+
+ meta = MetaData(testing.db)
+ t1 = Table('test', meta,
+ Column('id', sa.Integer, primary_key=True),
+ Column('data', sa.String(50)),
+ )
+ t2 = Table('test2', meta,
+ Column('id', sa.Integer, sa.ForeignKey('test.id'), primary_key=True),
+ Column('id2', sa.Integer, primary_key=True),
+ Column('data', sa.String(50)),
+ )
+ meta.create_all()
+ try:
+ m2 = MetaData(testing.db)
+ t1a = Table('test', m2, autoload=True)
+ assert t1a._autoincrement_column is t1a.c.id
+
+ t2a = Table('test2', m2, autoload=True)
+ assert t2a._autoincrement_column is t2a.c.id2
+
+ finally:
+ meta.drop_all()
+
def test_unknown_types(self):
meta = MetaData(testing.db)
t = Table("test", meta,
Column('foo', sa.DateTime))
- import sys
- dialect_module = sys.modules[testing.db.dialect.__module__]
-
- # we're relying on the presence of "ischema_names" in the
- # dialect module, else we can't test this. we need to be able
- # to get the dialect to not be aware of some type so we temporarily
- # monkeypatch. not sure what a better way for this could be,
- # except for an established dialect hook or dialect-specific tests
- if not hasattr(dialect_module, 'ischema_names'):
- return
-
- ischema_names = dialect_module.ischema_names
+ ischema_names = testing.db.dialect.ischema_names
t.create()
- dialect_module.ischema_names = {}
+ testing.db.dialect.ischema_names = {}
try:
m2 = MetaData(testing.db)
assert_raises(tsa.exc.SAWarning, Table, "test", m2, autoload=True)
@@ -115,7 +188,7 @@ class ReflectionTest(TestBase, ComparesTables):
assert t3.c.foo.type.__class__ == sa.types.NullType
finally:
- dialect_module.ischema_names = ischema_names
+ testing.db.dialect.ischema_names = ischema_names
t.drop()
def test_basic_override(self):
@@ -578,7 +651,6 @@ class ReflectionTest(TestBase, ComparesTables):
m9.reflect()
self.assert_(not m9.tables)
- @testing.fails_on_everything_except('postgres', 'mysql')
def test_index_reflection(self):
m1 = MetaData(testing.db)
t1 = Table('party', m1,
@@ -698,7 +770,7 @@ class UnicodeReflectionTest(TestBase):
def test_basic(self):
try:
# the 'convert_unicode' should not get in the way of the reflection
- # process. reflecttable for oracle, postgres (others?) expect non-unicode
+ # process. reflecttable for oracle, postgresql (others?) expect non-unicode
# strings in result sets/bind params
bind = engines.utf8_engine(options={'convert_unicode':True})
metadata = MetaData(bind)
@@ -713,7 +785,8 @@ class UnicodeReflectionTest(TestBase):
metadata.create_all()
reflected = set(bind.table_names())
- if not names.issubset(reflected):
+ # Jython 2.5 on Java 5 lacks unicodedata.normalize
+ if not names.issubset(reflected) and hasattr(unicodedata, 'normalize'):
# Python source files in the utf-8 coding seem to normalize
# literals as NFC (and the above are explicitly NFC). Maybe
# this database normalizes NFD on reflection.
@@ -741,23 +814,15 @@ class SchemaTest(TestBase):
Column('col1', sa.Integer, primary_key=True),
Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
schema='someschema')
- # ensure this doesnt crash
- print [t for t in metadata.sorted_tables]
- buf = StringIO.StringIO()
- def foo(s, p=None):
- buf.write(s)
- gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
- gen = gen.dialect.schemagenerator(gen.dialect, gen)
- gen.traverse(table1)
- gen.traverse(table2)
- buf = buf.getvalue()
- print buf
+
+ t1 = str(schema.CreateTable(table1).compile(bind=testing.db))
+ t2 = str(schema.CreateTable(table2).compile(bind=testing.db))
if testing.db.dialect.preparer(testing.db.dialect).omit_schema:
- assert buf.index("CREATE TABLE table1") > -1
- assert buf.index("CREATE TABLE table2") > -1
+ assert t1.index("CREATE TABLE table1") > -1
+ assert t2.index("CREATE TABLE table2") > -1
else:
- assert buf.index("CREATE TABLE someschema.table1") > -1
- assert buf.index("CREATE TABLE someschema.table2") > -1
+ assert t1.index("CREATE TABLE someschema.table1") > -1
+ assert t2.index("CREATE TABLE someschema.table2") > -1
@testing.crashes('firebird', 'No schema support')
@testing.fails_on('sqlite', 'FIXME: unknown')
@@ -767,9 +832,9 @@ class SchemaTest(TestBase):
def test_explicit_default_schema(self):
engine = testing.db
- if testing.against('mysql'):
+ if testing.against('mysql+mysqldb'):
schema = testing.db.url.database
- elif testing.against('postgres'):
+ elif testing.against('postgresql'):
schema = 'public'
elif testing.against('sqlite'):
# Works for CREATE TABLE main.foo, SELECT FROM main.foo, etc.,
@@ -820,4 +885,324 @@ class HasSequenceTest(TestBase):
metadata.drop_all(bind=testing.db)
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'), False)
+# Tests related to engine.reflection
+
+def get_schema():
+ if testing.against('oracle'):
+ return 'scott'
+ return 'test_schema'
+
+def createTables(meta, schema=None):
+ if schema:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('%s.users.user_id' % schema)
+ )
+ else:
+ parent_user_id = Column('parent_user_id', sa.Integer,
+ sa.ForeignKey('users.user_id')
+ )
+
+ users = Table('users', meta,
+ Column('user_id', sa.INT, primary_key=True),
+ Column('user_name', sa.VARCHAR(20), nullable=False),
+ Column('test1', sa.CHAR(5), nullable=False),
+ Column('test2', sa.Float(5), nullable=False),
+ Column('test3', sa.Text),
+ Column('test4', sa.Numeric(10, 2), nullable = False),
+ Column('test5', sa.DateTime),
+ Column('test5-1', sa.TIMESTAMP),
+ parent_user_id,
+ Column('test6', sa.DateTime, nullable=False),
+ Column('test7', sa.Text),
+ Column('test8', sa.Binary),
+ Column('test_passivedefault2', sa.Integer, server_default='5'),
+ Column('test9', sa.Binary(100)),
+ Column('test10', sa.Numeric(10, 2)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ addresses = Table('email_addresses', meta,
+ Column('address_id', sa.Integer, primary_key = True),
+ Column('remote_user_id', sa.Integer,
+ sa.ForeignKey(users.c.user_id)),
+ Column('email_address', sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ return (users, addresses)
+
+def createIndexes(con, schema=None):
+ fullname = 'users'
+ if schema:
+ fullname = "%s.%s" % (schema, 'users')
+ query = "CREATE INDEX users_t_idx ON %s (test1, test2)" % fullname
+ con.execute(sa.sql.text(query))
+
+def createViews(con, schema=None):
+ for table_name in ('users', 'email_addresses'):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + '_v'
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (view_name,
+ fullname)
+ con.execute(sa.sql.text(query))
+
+def dropViews(con, schema=None):
+ for table_name in ('email_addresses', 'users'):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + '_v'
+ query = "DROP VIEW %s" % view_name
+ con.execute(sa.sql.text(query))
+
+
+class ComponentReflectionTest(TestBase):
+
+ @testing.requires.schemas
+ def test_get_schema_names(self):
+ meta = MetaData(testing.db)
+ insp = Inspector(meta.bind)
+
+ self.assert_(get_schema() in insp.get_schema_names())
+
+ def _test_get_table_names(self, schema=None, table_type='table',
+ order_by=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createViews(meta.bind, schema)
+ try:
+ insp = Inspector(meta.bind)
+ if table_type == 'view':
+ table_names = insp.get_view_names(schema)
+ table_names.sort()
+ answer = ['email_addresses_v', 'users_v']
+ else:
+ table_names = insp.get_table_names(schema,
+ order_by=order_by)
+ table_names.sort()
+ if order_by == 'foreign_key':
+ answer = ['users', 'email_addresses']
+ else:
+ answer = ['email_addresses', 'users']
+ eq_(table_names, answer)
+ finally:
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_names(self):
+ self._test_get_table_names()
+
+ @testing.requires.schemas
+ def test_get_table_names_with_schema(self):
+ self._test_get_table_names(get_schema())
+
+ def test_get_view_names(self):
+ self._test_get_table_names(table_type='view')
+
+ @testing.requires.schemas
+ def test_get_view_names_with_schema(self):
+ self._test_get_table_names(get_schema(), table_type='view')
+
+ def _test_get_columns(self, schema=None, table_type='table'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ table_names = ['users', 'email_addresses']
+ meta.create_all()
+ if table_type == 'view':
+ createViews(meta.bind, schema)
+ table_names = ['users_v', 'email_addresses_v']
+ try:
+ insp = Inspector(meta.bind)
+ for (table_name, table) in zip(table_names, (users, addresses)):
+ schema_name = schema
+ cols = insp.get_columns(table_name, schema=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+ # should be in order
+ for (i, col) in enumerate(table.columns):
+ eq_(col.name, cols[i]['name'])
+ ctype = cols[i]['type'].__class__
+ ctype_def = col.type
+ if isinstance(ctype_def, sa.types.TypeEngine):
+ ctype_def = ctype_def.__class__
+
+ # Oracle returns Date for DateTime.
+ if testing.against('oracle') \
+ and ctype_def in (sql_types.Date, sql_types.DateTime):
+ ctype_def = sql_types.Date
+
+ # assert that the desired type and return type
+ # share a base within one of the generic types.
+ self.assert_(
+ len(
+ set(
+ ctype.__mro__
+ ).intersection(ctype_def.__mro__)
+ .intersection([sql_types.Integer, sql_types.Numeric,
+ sql_types.DateTime, sql_types.Date, sql_types.Time,
+ sql_types.String, sql_types.Binary])
+ ) > 0
+ ,("%s(%s), %s(%s)" % (col.name, col.type, cols[i]['name'],
+ ctype)))
+ finally:
+ if table_type == 'view':
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_columns(self):
+ self._test_get_columns()
+
+ @testing.requires.schemas
+ def test_get_columns_with_schema(self):
+ self._test_get_columns(schema=get_schema())
+
+ def test_get_view_columns(self):
+ self._test_get_columns(table_type='view')
+
+ @testing.requires.schemas
+ def test_get_view_columns_with_schema(self):
+ self._test_get_columns(schema=get_schema(), table_type='view')
+
+ def _test_get_primary_keys(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ users_pkeys = insp.get_primary_keys(users.name,
+ schema=schema)
+ eq_(users_pkeys, ['user_id'])
+ addr_pkeys = insp.get_primary_keys(addresses.name,
+ schema=schema)
+ eq_(addr_pkeys, ['address_id'])
+
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_primary_keys(self):
+ self._test_get_primary_keys()
+
+ @testing.fails_on('sqlite', 'no schemas')
+ def test_get_primary_keys_with_schema(self):
+ self._test_get_primary_keys(schema=get_schema())
+
+ def _test_get_foreign_keys(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ insp = Inspector(meta.bind)
+ try:
+ expected_schema = schema
+ # users
+ users_fkeys = insp.get_foreign_keys(users.name,
+ schema=schema)
+ fkey1 = users_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ eq_(fkey1['referred_schema'], expected_schema)
+ eq_(fkey1['referred_table'], users.name)
+ eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1['constrained_columns'], ['parent_user_id'])
+ #addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name,
+ schema=schema)
+ fkey1 = addr_fkeys[0]
+ self.assert_(fkey1['name'] is not None)
+ eq_(fkey1['referred_schema'], expected_schema)
+ eq_(fkey1['referred_table'], users.name)
+ eq_(fkey1['referred_columns'], ['user_id', ])
+ eq_(fkey1['constrained_columns'], ['remote_user_id'])
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_foreign_keys(self):
+ self._test_get_foreign_keys()
+
+ @testing.requires.schemas
+ def test_get_foreign_keys_with_schema(self):
+ self._test_get_foreign_keys(schema=get_schema())
+
+ def _test_get_indexes(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createIndexes(meta.bind, schema)
+ try:
+ # The database may decide to create indexes for foreign keys, etc.
+ # so there may be more indexes than expected.
+ insp = Inspector(meta.bind)
+ indexes = insp.get_indexes('users', schema=schema)
+ indexes.sort()
+ expected_indexes = [
+ {'unique': False,
+ 'column_names': ['test1', 'test2'],
+ 'name': 'users_t_idx'}]
+ index_names = [d['name'] for d in indexes]
+ for e_index in expected_indexes:
+ assert e_index['name'] in index_names
+ index = indexes[index_names.index(e_index['name'])]
+ for key in e_index:
+ eq_(e_index[key], index[key])
+
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_indexes(self):
+ self._test_get_indexes()
+
+ @testing.requires.schemas
+ def test_get_indexes_with_schema(self):
+ self._test_get_indexes(schema=get_schema())
+
+ def _test_get_view_definition(self, schema=None):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ createViews(meta.bind, schema)
+ view_name1 = 'users_v'
+ view_name2 = 'email_addresses_v'
+ try:
+ insp = Inspector(meta.bind)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
+ self.assert_(v2)
+ finally:
+ dropViews(meta.bind, schema)
+ addresses.drop()
+ users.drop()
+
+ def test_get_view_definition(self):
+ self._test_get_view_definition()
+
+ @testing.requires.schemas
+ def test_get_view_definition_with_schema(self):
+ self._test_get_view_definition(schema=get_schema())
+
+ def _test_get_table_oid(self, table_name, schema=None):
+ if testing.against('postgresql'):
+ meta = MetaData(testing.db)
+ (users, addresses) = createTables(meta, schema)
+ meta.create_all()
+ try:
+ insp = create_inspector(meta.bind)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, (int, long)))
+ finally:
+ addresses.drop()
+ users.drop()
+
+ def test_get_table_oid(self):
+ self._test_get_table_oid('users')
+
+ @testing.requires.schemas
+ def test_get_table_oid_with_schema(self):
+ self._test_get_table_oid('users', schema=get_schema())
+
diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py
index 6698259a4..8e3f3412d 100644
--- a/test/engine/test_transaction.py
+++ b/test/engine/test_transaction.py
@@ -20,7 +20,8 @@ class TransactionTest(TestBase):
users.create(testing.db)
def teardown(self):
- testing.db.connect().execute(users.delete())
+ testing.db.execute(users.delete()).close()
+
@classmethod
def teardown_class(cls):
users.drop(testing.db)
@@ -40,6 +41,7 @@ class TransactionTest(TestBase):
result = connection.execute("select * from query_users")
assert len(result.fetchall()) == 3
transaction.commit()
+ connection.close()
def test_rollback(self):
"""test a basic rollback"""
@@ -176,6 +178,7 @@ class TransactionTest(TestBase):
connection.close()
@testing.requires.savepoints
+ @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
def test_nested_subtransaction_commit(self):
connection = testing.db.connect()
transaction = connection.begin()
@@ -274,6 +277,7 @@ class TransactionTest(TestBase):
connection.close()
@testing.requires.two_phase_transactions
+ @testing.crashes('mysql+zxjdbc', 'Deadlocks, causing subsequent tests to fail')
@testing.fails_on('mysql', 'FIXME: unknown')
def test_two_phase_recover(self):
# MySQL recovery doesn't currently seem to work correctly
@@ -369,7 +373,7 @@ class ExplicitAutoCommitTest(TestBase):
Requires PostgreSQL so that we may define a custom function which modifies the database.
"""
- __only_on__ = 'postgres'
+ __only_on__ = 'postgresql'
@classmethod
def setup_class(cls):
@@ -380,7 +384,7 @@ class ExplicitAutoCommitTest(TestBase):
testing.db.execute("create function insert_foo(varchar) returns integer as 'insert into foo(data) values ($1);select 1;' language sql")
def teardown(self):
- foo.delete().execute()
+ foo.delete().execute().close()
@classmethod
def teardown_class(cls):
@@ -453,8 +457,10 @@ class TLTransactionTest(TestBase):
test_needs_acid=True,
)
users.create(tlengine)
+
def teardown(self):
- tlengine.execute(users.delete())
+ tlengine.execute(users.delete()).close()
+
@classmethod
def teardown_class(cls):
users.drop(tlengine)
@@ -497,6 +503,7 @@ class TLTransactionTest(TestBase):
try:
assert len(result.fetchall()) == 0
finally:
+ c.close()
external_connection.close()
def test_rollback(self):
@@ -530,7 +537,9 @@ class TLTransactionTest(TestBase):
external_connection.close()
def test_commits(self):
- assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
+ connection = tlengine.connect()
+ assert connection.execute("select count(1) from query_users").scalar() == 0
+ connection.close()
connection = tlengine.contextual_connect()
transaction = connection.begin()
@@ -547,6 +556,7 @@ class TLTransactionTest(TestBase):
l = result.fetchall()
assert len(l) == 3, "expected 3 got %d" % len(l)
transaction.commit()
+ connection.close()
def test_rollback_off_conn(self):
# test that a TLTransaction opened off a TLConnection allows that
@@ -563,6 +573,7 @@ class TLTransactionTest(TestBase):
try:
assert len(result.fetchall()) == 0
finally:
+ conn.close()
external_connection.close()
def test_morerollback_off_conn(self):
@@ -581,6 +592,8 @@ class TLTransactionTest(TestBase):
try:
assert len(result.fetchall()) == 0
finally:
+ conn.close()
+ conn2.close()
external_connection.close()
def test_commit_off_connection(self):
@@ -596,6 +609,7 @@ class TLTransactionTest(TestBase):
try:
assert len(result.fetchall()) == 3
finally:
+ conn.close()
external_connection.close()
def test_nesting(self):
@@ -712,8 +726,10 @@ class ForUpdateTest(TestBase):
test_needs_acid=True,
)
counters.create(testing.db)
+
def teardown(self):
- testing.db.connect().execute(counters.delete())
+ testing.db.execute(counters.delete()).close()
+
@classmethod
def teardown_class(cls):
counters.drop(testing.db)
@@ -726,7 +742,7 @@ class ForUpdateTest(TestBase):
for i in xrange(count):
trans = con.begin()
try:
- existing = con.execute(sel).fetchone()
+ existing = con.execute(sel).first()
incr = existing['counter_value'] + 1
time.sleep(delay)
@@ -734,7 +750,7 @@ class ForUpdateTest(TestBase):
values={'counter_value':incr}))
time.sleep(delay)
- readback = con.execute(sel).fetchone()
+ readback = con.execute(sel).first()
if (readback['counter_value'] != incr):
raise AssertionError("Got %s post-update, expected %s" %
(readback['counter_value'], incr))
@@ -778,7 +794,7 @@ class ForUpdateTest(TestBase):
self.assert_(len(errors) == 0)
sel = counters.select(whereclause=counters.c.counter_id==1)
- final = db.execute(sel).fetchone()
+ final = db.execute(sel).first()
self.assert_(final['counter_value'] == iterations * thread_count)
def overlap(self, ids, errors, update_style):