summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/test
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/test')
-rw-r--r--lib/sqlalchemy/test/assertsql.py1
-rw-r--r--lib/sqlalchemy/test/config.py50
-rw-r--r--lib/sqlalchemy/test/engines.py60
-rw-r--r--lib/sqlalchemy/test/noseplugin.py43
-rw-r--r--lib/sqlalchemy/test/profiling.py20
-rw-r--r--lib/sqlalchemy/test/requires.py55
-rw-r--r--lib/sqlalchemy/test/schema.py31
-rw-r--r--lib/sqlalchemy/test/testing.py133
-rw-r--r--lib/sqlalchemy/test/util.py24
9 files changed, 298 insertions, 119 deletions
diff --git a/lib/sqlalchemy/test/assertsql.py b/lib/sqlalchemy/test/assertsql.py
index dc2c6d40f..1af28794e 100644
--- a/lib/sqlalchemy/test/assertsql.py
+++ b/lib/sqlalchemy/test/assertsql.py
@@ -3,7 +3,6 @@ from sqlalchemy.interfaces import ConnectionProxy
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.base import Connection
from sqlalchemy import util
-import testing
import re
class AssertRule(object):
diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py
index 6ea5667cc..eec962d80 100644
--- a/lib/sqlalchemy/test/config.py
+++ b/lib/sqlalchemy/test/config.py
@@ -1,4 +1,8 @@
-import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+import optparse, os, sys, re, ConfigParser, time, warnings
+
+# 2to3
+import StringIO
+
logging = None
__all__ = 'parser', 'configure', 'options',
@@ -13,7 +17,11 @@ base_config = """
[db]
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
-postgres=postgres://scott:tiger@127.0.0.1:5432/test
+postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+postgres=postgresql://scott:tiger@127.0.0.1:5432/test
+pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
+postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
@@ -125,28 +133,22 @@ def _prep_testing_database(options, file_config):
from sqlalchemy.test import engines
from sqlalchemy import schema
- try:
- # also create alt schemas etc. here?
- if options.dropfirst:
- e = engines.utf8_engine()
- existing = e.table_names()
- if existing:
- print "Dropping existing tables in database: " + db_url
- try:
- print "Tables: %s" % ', '.join(existing)
- except:
- pass
- print "Abort within 5 seconds..."
- time.sleep(5)
- md = schema.MetaData(e, reflect=True)
- md.drop_all()
- e.dispose()
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception, e:
- warnings.warn(RuntimeWarning(
- "Error checking for existing tables in testing "
- "database: %s" % e))
+ # also create alt schemas etc. here?
+ if options.dropfirst:
+ e = engines.utf8_engine()
+ existing = e.table_names()
+ if existing:
+ print "Dropping existing tables in database: " + db_url
+ try:
+ print "Tables: %s" % ', '.join(existing)
+ except:
+ pass
+ print "Abort within 5 seconds..."
+ time.sleep(5)
+ md = schema.MetaData(e, reflect=True)
+ md.drop_all()
+ e.dispose()
+
post_configure['prep_db'] = _prep_testing_database
def _set_table_options(options, file_config):
diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py
index f0001978b..187ad2ff0 100644
--- a/lib/sqlalchemy/test/engines.py
+++ b/lib/sqlalchemy/test/engines.py
@@ -2,6 +2,7 @@ import sys, types, weakref
from collections import deque
import config
from sqlalchemy.util import function_named, callable
+import re
class ConnectionKiller(object):
def __init__(self):
@@ -11,7 +12,8 @@ class ConnectionKiller(object):
self.proxy_refs[con_proxy] = True
def _apply_all(self, methods):
- for rec in self.proxy_refs:
+ # must copy keys atomically
+ for rec in self.proxy_refs.keys():
if rec is not None and rec.is_valid:
try:
for name in methods:
@@ -38,6 +40,10 @@ class ConnectionKiller(object):
testing_reaper = ConnectionKiller()
+def drop_all_tables(metadata):
+ testing_reaper.close_all()
+ metadata.drop_all()
+
def assert_conns_closed(fn):
def decorated(*args, **kw):
try:
@@ -56,6 +62,14 @@ def rollback_open_connections(fn):
testing_reaper.rollback_all()
return function_named(decorated, fn.__name__)
+def close_first(fn):
+ """Decorator that closes all connections before fn execution."""
+ def decorated(*args, **kw):
+ testing_reaper.close_all()
+ fn(*args, **kw)
+ return function_named(decorated, fn.__name__)
+
+
def close_open_connections(fn):
"""Decorator that closes all connections after fn execution."""
@@ -69,7 +83,10 @@ def close_open_connections(fn):
def all_dialects():
import sqlalchemy.databases as d
for name in d.__all__:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ # TEMPORARY
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
class ReconnectFixture(object):
@@ -115,7 +132,11 @@ def testing_engine(url=None, options=None):
listeners.append(testing_reaper)
engine = create_engine(url, **options)
-
+
+ # may want to call this, results
+ # in first-connect initializers
+ #engine.connect()
+
return engine
def utf8_engine(url=None, options=None):
@@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None):
from sqlalchemy.engine import url as engine_url
- if config.db.name == 'mysql':
+ if config.db.driver == 'mysqldb':
dbapi_ver = config.db.dialect.dbapi.version_info
if (dbapi_ver < (1, 2, 1) or
dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
@@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None):
return testing_engine(url, options)
-def mock_engine(db=None):
- """Provides a mocking engine based on the current testing.db."""
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
from sqlalchemy import create_engine
- dbi = db or config.db
+ if not dialect_name:
+ dialect_name = config.db.name
+
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
- engine = create_engine(dbi.name + '://',
+ def assert_sql(stmts):
+ recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
+ assert recv == stmts, recv
+
+ engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.mock = buffer
+ engine.assert_sql = assert_sql
return engine
class ReplayableSession(object):
@@ -168,9 +205,16 @@ class ReplayableSession(object):
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]). \
difference([getattr(types, t)
+ # Py3K
+ #for t in ('FunctionType', 'BuiltinFunctionType',
+ # 'MethodType', 'BuiltinMethodType',
+ # 'LambdaType', )])
+
+ # Py2K
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', 'UnboundMethodType',)])
+ # end Py2K
def __init__(self):
self.buffer = deque()
diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py
index 263d2d783..c4f32a163 100644
--- a/lib/sqlalchemy/test/noseplugin.py
+++ b/lib/sqlalchemy/test/noseplugin.py
@@ -14,7 +14,7 @@ from config import db, db_label, db_url, file_config, base_config, \
_set_table_options, _reverse_topological, _log
from sqlalchemy.test import testing, config, requires
from nose.plugins import Plugin
-from nose.util import tolist
+from sqlalchemy import util
import nose.case
log = logging.getLogger('nose.plugins.sqlalchemy')
@@ -30,9 +30,6 @@ class NoseSQLAlchemy(Plugin):
def options(self, parser, env=os.environ):
Plugin.options(self, parser, env)
opt = parser.add_option
- #opt("--verbose", action="store_true", dest="verbose",
- #help="enable stdout echoing/printing")
- #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
opt("--log-info", action="callback", type="string", callback=_log,
help="turn on info logging for <LOG> (multiple OK)")
opt("--log-debug", action="callback", type="string", callback=_log,
@@ -77,15 +74,16 @@ class NoseSQLAlchemy(Plugin):
def configure(self, options, conf):
Plugin.configure(self, options, conf)
-
- import testing, requires
+ self.options = options
+
+ def begin(self):
testing.db = db
testing.requires = requires
# Lazy setup of other options (post coverage)
for fn in post_configure:
- fn(options, file_config)
-
+ fn(self.options, file_config)
+
def describeTest(self, test):
return ""
@@ -117,15 +115,20 @@ class NoseSQLAlchemy(Plugin):
if check(test_suite)() != 'ok':
# The requirement will perform messaging.
return True
- if (hasattr(cls, '__unsupported_on__') and
- testing.db.name in cls.__unsupported_on__):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, testing.db.name)
- return True
- if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
- print "'%s' unsupported on DB implementation '%s'" % (
- cls.__class__.__name__, testing.db.name)
- return True
+
+ if cls.__unsupported_on__:
+ spec = testing.db_spec(*cls.__unsupported_on__)
+ if spec(testing.db):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+ if getattr(cls, '__only_on__', None):
+ spec = testing.db_spec(*util.to_list(cls.__only_on__))
+ if not spec(testing.db):
+ print "'%s' unsupported on DB implementation '%s'" % (
+ cls.__class__.__name__, testing.db.name)
+ return True
+
if (getattr(cls, '__skip_if__', False)):
for c in getattr(cls, '__skip_if__'):
if c():
@@ -140,15 +143,15 @@ class NoseSQLAlchemy(Plugin):
return True
return False
- #def begin(self):
- #pass
-
def beforeTest(self, test):
testing.resetwarnings()
def afterTest(self, test):
testing.resetwarnings()
+ def afterContext(self):
+ testing.global_cleanup_assertions()
+
#def handleError(self, test, err):
#pass
diff --git a/lib/sqlalchemy/test/profiling.py b/lib/sqlalchemy/test/profiling.py
index ca4b31cbd..8cab6ceba 100644
--- a/lib/sqlalchemy/test/profiling.py
+++ b/lib/sqlalchemy/test/profiling.py
@@ -6,8 +6,9 @@ in a more fine-grained way than nose's profiling plugin.
"""
import os, sys
-from sqlalchemy.util import function_named
-import config
+from sqlalchemy.test import config
+from sqlalchemy.test.util import function_named, gc_collect
+from nose import SkipTest
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
@@ -162,15 +163,22 @@ def conditional_call_count(discriminator, categories):
def _profile(filename, fn, *args, **kw):
global profiler
if not profiler:
- profiler = 'hotshot'
if sys.version_info > (2, 5):
try:
import cProfile
profiler = 'cProfile'
except ImportError:
pass
+ if not profiler:
+ try:
+ import hotshot
+ profiler = 'hotshot'
+ except ImportError:
+ profiler = 'skip'
- if profiler == 'cProfile':
+ if profiler == 'skip':
+ raise SkipTest('Profiling not supported on this platform')
+ elif profiler == 'cProfile':
return _profile_cProfile(filename, fn, *args, **kw)
else:
return _profile_hotshot(filename, fn, *args, **kw)
@@ -179,7 +187,7 @@ def _profile_cProfile(filename, fn, *args, **kw):
import cProfile, gc, pstats, time
load_stats = lambda: pstats.Stats(filename)
- gc.collect()
+ gc_collect()
began = time.time()
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
@@ -192,7 +200,7 @@ def _profile_hotshot(filename, fn, *args, **kw):
import gc, hotshot, hotshot.stats, time
load_stats = lambda: hotshot.stats.load(filename)
- gc.collect()
+ gc_collect()
prof = hotshot.Profile(filename)
began = time.time()
prof.start()
diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py
index b23b8620d..f3f4ec191 100644
--- a/lib/sqlalchemy/test/requires.py
+++ b/lib/sqlalchemy/test/requires.py
@@ -28,6 +28,25 @@ def foreign_keys(fn):
no_support('sqlite', 'not supported by database'),
)
+
+def unbounded_varchar(fn):
+ """Target database must support VARCHAR with no length"""
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ )
+
+def boolean_col_expressions(fn):
+ """Target database must support boolean expressions as columns"""
+ return _chain_decorators_on(
+ fn,
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('mssql', 'not supported by database'),
+ )
+
def identity(fn):
"""Target database must support GENERATED AS IDENTITY or a facsimile.
@@ -40,7 +59,7 @@ def identity(fn):
fn,
no_support('firebird', 'not supported by database'),
no_support('oracle', 'not supported by database'),
- no_support('postgres', 'not supported by database'),
+ no_support('postgresql', 'not supported by database'),
no_support('sybase', 'not supported by database'),
)
@@ -61,9 +80,19 @@ def row_triggers(fn):
# no access to same table
no_support('mysql', 'requires SUPER priv'),
exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
- no_support('postgres', 'not supported by database: no statements'),
+
+ # huh? TODO: implement triggers for PG tests, remove this
+ no_support('postgresql', 'PG triggers need to be implemented for tests'),
)
+def correlated_outer_joins(fn):
+ """Target must support an outer join to a subquery which correlates to the parent."""
+
+ return _chain_decorators_on(
+ fn,
+ no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
+ )
+
def savepoints(fn):
"""Target database must support savepoints."""
return _chain_decorators_on(
@@ -75,6 +104,15 @@ def savepoints(fn):
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
)
+def schemas(fn):
+ """Target database must support external schemas, and have one named 'test_schema'."""
+
+ return _chain_decorators_on(
+ fn,
+ no_support('sqlite', 'no schema support'),
+ no_support('firebird', 'no schema support')
+ )
+
def sequences(fn):
"""Target database must support SEQUENCEs."""
return _chain_decorators_on(
@@ -93,6 +131,17 @@ def subqueries(fn):
exclude('mysql', '<', (4, 1, 1), 'no subquery support'),
)
+def returning(fn):
+ return _chain_decorators_on(
+ fn,
+ no_support('access', 'not supported by database'),
+ no_support('sqlite', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ no_support('maxdb', 'not supported by database'),
+ no_support('sybase', 'not supported by database'),
+ no_support('informix', 'not supported by database'),
+ )
+
def two_phase_transactions(fn):
"""Target database must support two-phase transactions."""
return _chain_decorators_on(
@@ -104,6 +153,8 @@ def two_phase_transactions(fn):
no_support('oracle', 'no SA implementation'),
no_support('sqlite', 'not supported by database'),
no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may '
+ 'need separate XA implementation'),
exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
)
diff --git a/lib/sqlalchemy/test/schema.py b/lib/sqlalchemy/test/schema.py
index f96805fe4..35b4060d2 100644
--- a/lib/sqlalchemy/test/schema.py
+++ b/lib/sqlalchemy/test/schema.py
@@ -33,7 +33,7 @@ def Table(*args, **kw):
# expand to ForeignKeyConstraint too.
fks = [fk
for col in args if isinstance(col, schema.Column)
- for fk in col.args if isinstance(fk, schema.ForeignKey)]
+ for fk in col.foreign_keys]
for fk in fks:
# root around in raw spec
@@ -51,13 +51,6 @@ def Table(*args, **kw):
if fk.onupdate is None:
fk.onupdate = 'CASCADE'
- if testing.against('firebird', 'oracle'):
- pk_seqs = [col for col in args
- if (isinstance(col, schema.Column)
- and col.primary_key
- and getattr(col, '_needs_autoincrement', False))]
- for c in pk_seqs:
- c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True))
return schema.Table(*args, **kw)
@@ -67,8 +60,20 @@ def Column(*args, **kw):
test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
- c = schema.Column(*args, **kw)
- if testing.against('firebird', 'oracle'):
- if 'test_needs_autoincrement' in test_opts:
- c._needs_autoincrement = True
- return c
+ col = schema.Column(*args, **kw)
+ if 'test_needs_autoincrement' in test_opts and \
+ kw.get('primary_key', False) and \
+ testing.against('firebird', 'oracle'):
+ def add_seq(tbl):
+ col._init_items(
+ schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + col.name + '_seq'), optional=True)
+ )
+ col._on_table_attach(add_seq)
+ return col
+
+def _truncate_name(dialect, name):
+ if len(name) > dialect.max_identifier_length:
+ return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
+ else:
+ return name
+ \ No newline at end of file
diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py
index 36c7d340a..16a13d9d3 100644
--- a/lib/sqlalchemy/test/testing.py
+++ b/lib/sqlalchemy/test/testing.py
@@ -8,10 +8,12 @@ import types
import warnings
from cStringIO import StringIO
-from sqlalchemy.test import config, assertsql
+from sqlalchemy.test import config, assertsql, util as testutil
from sqlalchemy.util import function_named
+from engines import drop_all_tables
-from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema
+from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool
+from nose import SkipTest
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -80,6 +82,19 @@ def future(fn):
"Unexpected success for future test '%s'" % fn_name)
return function_named(decorated, fn_name)
+def db_spec(*dbs):
+ dialects = set([x for x in dbs if '+' not in x])
+ drivers = set([x[1:] for x in dbs if x.startswith('+')])
+ specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers])
+
+ def check(engine):
+ return engine.name in dialects or \
+ engine.driver in drivers or \
+ (engine.name, engine.driver) in specs
+
+ return check
+
+
def fails_on(dbs, reason):
"""Mark a test as expected to fail on the specified database
implementation.
@@ -90,23 +105,25 @@ def fails_on(dbs, reason):
succeeds, a failure is reported.
"""
+ spec = db_spec(dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name != dbs:
+ if not spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, reason))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
@@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs):
databases except those listed.
"""
+ spec = db_spec(*dbs)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name in dbs:
+ if spec(config.db):
return fn(*args, **kw)
else:
try:
fn(*args, **kw)
except Exception, ex:
print ("'%s' failed as expected on DB implementation "
- "'%s': %s" % (
- fn_name, config.db.name, str(ex)))
+ "'%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, str(ex)))
return True
else:
raise AssertionError(
- "Unexpected success for '%s' on DB implementation '%s'" %
- (fn_name, config.db.name))
+ "Unexpected success for '%s' on DB implementation '%s+%s'" %
+ (fn_name, config.db.name, config.db.driver))
return function_named(maybe, fn_name)
return decorate
@@ -145,12 +164,13 @@ def crashes(db, reason):
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
@@ -169,12 +189,13 @@ def _block_unconditionally(db, reason):
"""
carp = _should_carp_about_exclusion(reason)
+ spec = db_spec(db)
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
- if config.db.name == db:
- msg = "'%s' unsupported on DB implementation '%s': %s" % (
- fn_name, config.db.name, reason)
+ if spec(config.db):
+ msg = "'%s' unsupported on DB implementation '%s+%s': %s" % (
+ fn_name, config.db.name, config.db.driver, reason)
print msg
if carp:
print >> sys.stderr, msg
@@ -198,6 +219,7 @@ def exclude(db, op, spec, reason):
"""
carp = _should_carp_about_exclusion(reason)
+
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
@@ -242,7 +264,9 @@ def _is_excluded(db, op, spec):
_is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3')))
"""
- if config.db.name != db:
+ vendor_spec = db_spec(db)
+
+ if not vendor_spec(config.db):
return False
version = _server_version()
@@ -255,7 +279,12 @@ def _server_version(bind=None):
if bind is None:
bind = config.db
- return bind.dialect.server_version_info(bind.contextual_connect())
+
+ # force metadata to be retrieved
+ conn = bind.connect()
+ version = getattr(bind.dialect, 'server_version_info', ())
+ conn.close()
+ return version
def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
@@ -266,8 +295,7 @@ def skip_if(predicate, reason=None):
if predicate():
msg = "'%s' skipped on DB %s version '%s': %s" % (
fn_name, config.db.name, _server_version(), reason)
- print msg
- return True
+ raise SkipTest(msg)
else:
return fn(*args, **kw)
return function_named(maybe, fn_name)
@@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings):
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
+ spec = db_spec(db)
+
def decorate(fn):
def maybe(*args, **kw):
if isinstance(db, basestring):
- if config.db.name != db:
+ if not spec(config.db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
@@ -384,6 +414,19 @@ def resetwarnings():
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs
+
+
def against(*queries):
"""Boolean predicate, compares to testing database configuration.
@@ -394,21 +437,20 @@ def against(*queries):
Also supports comparison to database version when provided with one or
more 3-tuples of dialect name, operator, and version specification::
- testing.against('mysql', 'postgres')
+ testing.against('mysql', 'postgresql')
testing.against(('mysql', '>=', (5, 0, 0))
"""
for query in queries:
if isinstance(query, basestring):
- if config.db.name == query:
+ if db_spec(query)(config.db):
return True
else:
name, op, spec = query
- if config.db.name != name:
+ if not db_spec(name)(config.db):
continue
- have = config.db.dialect.server_version_info(
- config.db.contextual_connect())
+ have = _server_version()
oper = hasattr(op, '__call__') and op or _ops[op]
if oper(have, spec):
@@ -545,16 +587,15 @@ class AssertsCompiledSQL(object):
if dialect is None:
dialect = getattr(self, '__dialect__', None)
- if params is None:
- keys = None
- else:
- keys = params.keys()
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
- c = clause.compile(column_keys=keys, dialect=dialect)
+ c = clause.compile(dialect=dialect, **kw)
- print "\nSQL String:\n" + str(c) + repr(c.params)
+ print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
- cc = re.sub(r'\n', '', str(c))
+ cc = re.sub(r'[\n\t]', '', str(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
@@ -563,18 +604,13 @@ class AssertsCompiledSQL(object):
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table):
- base_mro = sqltypes.TypeEngine.__mro__
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
- assert len(
- set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
- set(type(c.type).__mro__).difference(base_mro)
- )
- ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
@@ -586,14 +622,21 @@ class ComparesTables(object):
elif against(('mysql', '<', (5, 0))):
# ignore reflection of bogus db-generated DefaultClause()
pass
- elif not c.primary_key or not against('postgres'):
- print repr(c)
+ elif not c.primary_key or not against('postgresql', 'mssql'):
+ #print repr(c)
assert reflected_c.default is None, reflected_c.default
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
+ def assert_types_base(self, c1, c2):
+ base_mro = sqltypes.TypeEngine.__mro__
+ assert len(
+ set(type(c1.type).__mro__).difference(base_mro).intersection(
+ set(type(c2.type).__mro__).difference(base_mro)
+ )
+ ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
@@ -678,7 +721,7 @@ class AssertsExecutionResults(object):
assertsql.asserter.clear_rules()
def assert_sql(self, db, callable_, list_, with_sequences=None):
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'):
+ if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
rules = with_sequences
else:
rules = list_
diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py
new file mode 100644
index 000000000..60b0a4ef8
--- /dev/null
+++ b/lib/sqlalchemy/test/util.py
@@ -0,0 +1,24 @@
+from sqlalchemy.util import jython, function_named
+
+import gc
+import time
+
+if jython:
+ def gc_collect(*args):
+ """aggressive gc.collect for tests."""
+ gc.collect()
+ time.sleep(0.1)
+ gc.collect()
+ gc.collect()
+ return 0
+
+ # "lazy" gc, for VM's that don't GC on refcount == 0
+ lazy_gc = gc_collect
+
+else:
+ # assume CPython - straight gc.collect, lazy_gc() is a pass
+ gc_collect = gc.collect
+ def lazy_gc():
+ pass
+
+