diff options
Diffstat (limited to 'lib/sqlalchemy/test')
| -rw-r--r-- | lib/sqlalchemy/test/assertsql.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/config.py | 50 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/engines.py | 60 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/noseplugin.py | 43 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/profiling.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/requires.py | 55 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/schema.py | 31 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/testing.py | 133 | ||||
| -rw-r--r-- | lib/sqlalchemy/test/util.py | 24 |
9 files changed, 298 insertions, 119 deletions
diff --git a/lib/sqlalchemy/test/assertsql.py b/lib/sqlalchemy/test/assertsql.py index dc2c6d40f..1af28794e 100644 --- a/lib/sqlalchemy/test/assertsql.py +++ b/lib/sqlalchemy/test/assertsql.py @@ -3,7 +3,6 @@ from sqlalchemy.interfaces import ConnectionProxy from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.engine.base import Connection from sqlalchemy import util -import testing import re class AssertRule(object): diff --git a/lib/sqlalchemy/test/config.py b/lib/sqlalchemy/test/config.py index 6ea5667cc..eec962d80 100644 --- a/lib/sqlalchemy/test/config.py +++ b/lib/sqlalchemy/test/config.py @@ -1,4 +1,8 @@ -import optparse, os, sys, re, ConfigParser, StringIO, time, warnings +import optparse, os, sys, re, ConfigParser, time, warnings + +# 2to3 +import StringIO + logging = None __all__ = 'parser', 'configure', 'options', @@ -13,7 +17,11 @@ base_config = """ [db] sqlite=sqlite:///:memory: sqlite_file=sqlite:///querytest.db -postgres=postgres://scott:tiger@127.0.0.1:5432/test +postgresql=postgresql://scott:tiger@127.0.0.1:5432/test +postgres=postgresql://scott:tiger@127.0.0.1:5432/test +pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test +postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test +mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test mysql=mysql://scott:tiger@127.0.0.1:3306/test oracle=oracle://scott:tiger@127.0.0.1:1521 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0 @@ -125,28 +133,22 @@ def _prep_testing_database(options, file_config): from sqlalchemy.test import engines from sqlalchemy import schema - try: - # also create alt schemas etc. here? - if options.dropfirst: - e = engines.utf8_engine() - existing = e.table_names() - if existing: - print "Dropping existing tables in database: " + db_url - try: - print "Tables: %s" % ', '.join(existing) - except: - pass - print "Abort within 5 seconds..." - time.sleep(5) - md = schema.MetaData(e, reflect=True) - md.drop_all() - e.dispose() - except (KeyboardInterrupt, SystemExit): - raise - except Exception, e: - warnings.warn(RuntimeWarning( - "Error checking for existing tables in testing " - "database: %s" % e)) + # also create alt schemas etc. here? + if options.dropfirst: + e = engines.utf8_engine() + existing = e.table_names() + if existing: + print "Dropping existing tables in database: " + db_url + try: + print "Tables: %s" % ', '.join(existing) + except: + pass + print "Abort within 5 seconds..." + time.sleep(5) + md = schema.MetaData(e, reflect=True) + md.drop_all() + e.dispose() + post_configure['prep_db'] = _prep_testing_database def _set_table_options(options, file_config): diff --git a/lib/sqlalchemy/test/engines.py b/lib/sqlalchemy/test/engines.py index f0001978b..187ad2ff0 100644 --- a/lib/sqlalchemy/test/engines.py +++ b/lib/sqlalchemy/test/engines.py @@ -2,6 +2,7 @@ import sys, types, weakref from collections import deque import config from sqlalchemy.util import function_named, callable +import re class ConnectionKiller(object): def __init__(self): @@ -11,7 +12,8 @@ class ConnectionKiller(object): self.proxy_refs[con_proxy] = True def _apply_all(self, methods): - for rec in self.proxy_refs: + # must copy keys atomically + for rec in self.proxy_refs.keys(): if rec is not None and rec.is_valid: try: for name in methods: @@ -38,6 +40,10 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() +def drop_all_tables(metadata): + testing_reaper.close_all() + metadata.drop_all() + def assert_conns_closed(fn): def decorated(*args, **kw): try: @@ -56,6 +62,14 @@ def rollback_open_connections(fn): testing_reaper.rollback_all() return function_named(decorated, fn.__name__) +def close_first(fn): + """Decorator that closes all connections before fn execution.""" + def decorated(*args, **kw): + testing_reaper.close_all() + fn(*args, **kw) + return function_named(decorated, fn.__name__) + + def close_open_connections(fn): """Decorator that closes all connections after fn execution.""" @@ -69,7 +83,10 @@ def close_open_connections(fn): def all_dialects(): import sqlalchemy.databases as d for name in d.__all__: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + # TEMPORARY + mod = getattr(d, name, None) + if not mod: + mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() class ReconnectFixture(object): @@ -115,7 +132,11 @@ def testing_engine(url=None, options=None): listeners.append(testing_reaper) engine = create_engine(url, **options) - + + # may want to call this, results + # in first-connect initializers + #engine.connect() + return engine def utf8_engine(url=None, options=None): @@ -123,7 +144,7 @@ def utf8_engine(url=None, options=None): from sqlalchemy.engine import url as engine_url - if config.db.name == 'mysql': + if config.db.driver == 'mysqldb': dbapi_ver = config.db.dialect.dbapi.version_info if (dbapi_ver < (1, 2, 1) or dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2), @@ -139,19 +160,35 @@ def utf8_engine(url=None, options=None): return testing_engine(url, options) -def mock_engine(db=None): - """Provides a mocking engine based on the current testing.db.""" +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ from sqlalchemy import create_engine - dbi = db or config.db + if not dialect_name: + dialect_name = config.db.name + buffer = [] def executor(sql, *a, **kw): buffer.append(sql) - engine = create_engine(dbi.name + '://', + def assert_sql(stmts): + recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] + assert recv == stmts, recv + + engine = create_engine(dialect_name + '://', strategy='mock', executor=executor) assert not hasattr(engine, 'mock') engine.mock = buffer + engine.assert_sql = assert_sql return engine class ReplayableSession(object): @@ -168,9 +205,16 @@ class ReplayableSession(object): Natives = set([getattr(types, t) for t in dir(types) if not t.startswith('_')]). \ difference([getattr(types, t) + # Py3K + #for t in ('FunctionType', 'BuiltinFunctionType', + # 'MethodType', 'BuiltinMethodType', + # 'LambdaType', )]) + + # Py2K for t in ('FunctionType', 'BuiltinFunctionType', 'MethodType', 'BuiltinMethodType', 'LambdaType', 'UnboundMethodType',)]) + # end Py2K def __init__(self): self.buffer = deque() diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py index 263d2d783..c4f32a163 100644 --- a/lib/sqlalchemy/test/noseplugin.py +++ b/lib/sqlalchemy/test/noseplugin.py @@ -14,7 +14,7 @@ from config import db, db_label, db_url, file_config, base_config, \ _set_table_options, _reverse_topological, _log from sqlalchemy.test import testing, config, requires from nose.plugins import Plugin -from nose.util import tolist +from sqlalchemy import util import nose.case log = logging.getLogger('nose.plugins.sqlalchemy') @@ -30,9 +30,6 @@ class NoseSQLAlchemy(Plugin): def options(self, parser, env=os.environ): Plugin.options(self, parser, env) opt = parser.add_option - #opt("--verbose", action="store_true", dest="verbose", - #help="enable stdout echoing/printing") - #opt("--quiet", action="store_true", dest="quiet", help="suppress output") opt("--log-info", action="callback", type="string", callback=_log, help="turn on info logging for <LOG> (multiple OK)") opt("--log-debug", action="callback", type="string", callback=_log, @@ -77,15 +74,16 @@ class NoseSQLAlchemy(Plugin): def configure(self, options, conf): Plugin.configure(self, options, conf) - - import testing, requires + self.options = options + + def begin(self): testing.db = db testing.requires = requires # Lazy setup of other options (post coverage) for fn in post_configure: - fn(options, file_config) - + fn(self.options, file_config) + def describeTest(self, test): return "" @@ -117,15 +115,20 @@ class NoseSQLAlchemy(Plugin): if check(test_suite)() != 'ok': # The requirement will perform messaging. return True - if (hasattr(cls, '__unsupported_on__') and - testing.db.name in cls.__unsupported_on__): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, testing.db.name) - return True - if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)): - print "'%s' unsupported on DB implementation '%s'" % ( - cls.__class__.__name__, testing.db.name) - return True + + if cls.__unsupported_on__: + spec = testing.db_spec(*cls.__unsupported_on__) + if spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if getattr(cls, '__only_on__', None): + spec = testing.db_spec(*util.to_list(cls.__only_on__)) + if not spec(testing.db): + print "'%s' unsupported on DB implementation '%s'" % ( + cls.__class__.__name__, testing.db.name) + return True + if (getattr(cls, '__skip_if__', False)): for c in getattr(cls, '__skip_if__'): if c(): @@ -140,15 +143,15 @@ class NoseSQLAlchemy(Plugin): return True return False - #def begin(self): - #pass - def beforeTest(self, test): testing.resetwarnings() def afterTest(self, test): testing.resetwarnings() + def afterContext(self): + testing.global_cleanup_assertions() + #def handleError(self, test, err): #pass diff --git a/lib/sqlalchemy/test/profiling.py b/lib/sqlalchemy/test/profiling.py index ca4b31cbd..8cab6ceba 100644 --- a/lib/sqlalchemy/test/profiling.py +++ b/lib/sqlalchemy/test/profiling.py @@ -6,8 +6,9 @@ in a more fine-grained way than nose's profiling plugin. """ import os, sys -from sqlalchemy.util import function_named -import config +from sqlalchemy.test import config +from sqlalchemy.test.util import function_named, gc_collect +from nose import SkipTest __all__ = 'profiled', 'function_call_count', 'conditional_call_count' @@ -162,15 +163,22 @@ def conditional_call_count(discriminator, categories): def _profile(filename, fn, *args, **kw): global profiler if not profiler: - profiler = 'hotshot' if sys.version_info > (2, 5): try: import cProfile profiler = 'cProfile' except ImportError: pass + if not profiler: + try: + import hotshot + profiler = 'hotshot' + except ImportError: + profiler = 'skip' - if profiler == 'cProfile': + if profiler == 'skip': + raise SkipTest('Profiling not supported on this platform') + elif profiler == 'cProfile': return _profile_cProfile(filename, fn, *args, **kw) else: return _profile_hotshot(filename, fn, *args, **kw) @@ -179,7 +187,7 @@ def _profile_cProfile(filename, fn, *args, **kw): import cProfile, gc, pstats, time load_stats = lambda: pstats.Stats(filename) - gc.collect() + gc_collect() began = time.time() cProfile.runctx('result = fn(*args, **kw)', globals(), locals(), @@ -192,7 +200,7 @@ def _profile_hotshot(filename, fn, *args, **kw): import gc, hotshot, hotshot.stats, time load_stats = lambda: hotshot.stats.load(filename) - gc.collect() + gc_collect() prof = hotshot.Profile(filename) began = time.time() prof.start() diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index b23b8620d..f3f4ec191 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -28,6 +28,25 @@ def foreign_keys(fn): no_support('sqlite', 'not supported by database'), ) + +def unbounded_varchar(fn): + """Target database must support VARCHAR with no length""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mysql', 'not supported by database'), + ) + +def boolean_col_expressions(fn): + """Target database must support boolean expressions as columns""" + return _chain_decorators_on( + fn, + no_support('firebird', 'not supported by database'), + no_support('oracle', 'not supported by database'), + no_support('mssql', 'not supported by database'), + ) + def identity(fn): """Target database must support GENERATED AS IDENTITY or a facsimile. @@ -40,7 +59,7 @@ def identity(fn): fn, no_support('firebird', 'not supported by database'), no_support('oracle', 'not supported by database'), - no_support('postgres', 'not supported by database'), + no_support('postgresql', 'not supported by database'), no_support('sybase', 'not supported by database'), ) @@ -61,9 +80,19 @@ def row_triggers(fn): # no access to same table no_support('mysql', 'requires SUPER priv'), exclude('mysql', '<', (5, 0, 10), 'not supported by database'), - no_support('postgres', 'not supported by database: no statements'), + + # huh? TODO: implement triggers for PG tests, remove this + no_support('postgresql', 'PG triggers need to be implemented for tests'), ) +def correlated_outer_joins(fn): + """Target must support an outer join to a subquery which correlates to the parent.""" + + return _chain_decorators_on( + fn, + no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"') + ) + def savepoints(fn): """Target database must support savepoints.""" return _chain_decorators_on( @@ -75,6 +104,15 @@ def savepoints(fn): exclude('mysql', '<', (5, 0, 3), 'not supported by database'), ) +def schemas(fn): + """Target database must support external schemas, and have one named 'test_schema'.""" + + return _chain_decorators_on( + fn, + no_support('sqlite', 'no schema support'), + no_support('firebird', 'no schema support') + ) + def sequences(fn): """Target database must support SEQUENCEs.""" return _chain_decorators_on( @@ -93,6 +131,17 @@ def subqueries(fn): exclude('mysql', '<', (4, 1, 1), 'no subquery support'), ) +def returning(fn): + return _chain_decorators_on( + fn, + no_support('access', 'not supported by database'), + no_support('sqlite', 'not supported by database'), + no_support('mysql', 'not supported by database'), + no_support('maxdb', 'not supported by database'), + no_support('sybase', 'not supported by database'), + no_support('informix', 'not supported by database'), + ) + def two_phase_transactions(fn): """Target database must support two-phase transactions.""" return _chain_decorators_on( @@ -104,6 +153,8 @@ def two_phase_transactions(fn): no_support('oracle', 'no SA implementation'), no_support('sqlite', 'not supported by database'), no_support('sybase', 'FIXME: guessing, needs confirmation'), + no_support('postgresql+zxjdbc', 'FIXME: JDBC driver confuses the transaction state, may ' + 'need separate XA implementation'), exclude('mysql', '<', (5, 0, 3), 'not supported by database'), ) diff --git a/lib/sqlalchemy/test/schema.py b/lib/sqlalchemy/test/schema.py index f96805fe4..35b4060d2 100644 --- a/lib/sqlalchemy/test/schema.py +++ b/lib/sqlalchemy/test/schema.py @@ -33,7 +33,7 @@ def Table(*args, **kw): # expand to ForeignKeyConstraint too. fks = [fk for col in args if isinstance(col, schema.Column) - for fk in col.args if isinstance(fk, schema.ForeignKey)] + for fk in col.foreign_keys] for fk in fks: # root around in raw spec @@ -51,13 +51,6 @@ def Table(*args, **kw): if fk.onupdate is None: fk.onupdate = 'CASCADE' - if testing.against('firebird', 'oracle'): - pk_seqs = [col for col in args - if (isinstance(col, schema.Column) - and col.primary_key - and getattr(col, '_needs_autoincrement', False))] - for c in pk_seqs: - c.args.append(schema.Sequence(args[0] + '_' + c.name + '_seq', optional=True)) return schema.Table(*args, **kw) @@ -67,8 +60,20 @@ def Column(*args, **kw): test_opts = dict([(k,kw.pop(k)) for k in kw.keys() if k.startswith('test_')]) - c = schema.Column(*args, **kw) - if testing.against('firebird', 'oracle'): - if 'test_needs_autoincrement' in test_opts: - c._needs_autoincrement = True - return c + col = schema.Column(*args, **kw) + if 'test_needs_autoincrement' in test_opts and \ + kw.get('primary_key', False) and \ + testing.against('firebird', 'oracle'): + def add_seq(tbl): + col._init_items( + schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + col.name + '_seq'), optional=True) + ) + col._on_table_attach(add_seq) + return col + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:] + else: + return name +
\ No newline at end of file diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 36c7d340a..16a13d9d3 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -8,10 +8,12 @@ import types import warnings from cStringIO import StringIO -from sqlalchemy.test import config, assertsql +from sqlalchemy.test import config, assertsql, util as testutil from sqlalchemy.util import function_named +from engines import drop_all_tables -from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool +from nose import SkipTest _ops = { '<': operator.lt, '>': operator.gt, @@ -80,6 +82,19 @@ def future(fn): "Unexpected success for future test '%s'" % fn_name) return function_named(decorated, fn_name) +def db_spec(*dbs): + dialects = set([x for x in dbs if '+' not in x]) + drivers = set([x[1:] for x in dbs if x.startswith('+')]) + specs = set([tuple(x.split('+')) for x in dbs if '+' in x and x not in drivers]) + + def check(engine): + return engine.name in dialects or \ + engine.driver in drivers or \ + (engine.name, engine.driver) in specs + + return check + + def fails_on(dbs, reason): """Mark a test as expected to fail on the specified database implementation. @@ -90,23 +105,25 @@ def fails_on(dbs, reason): succeeds, a failure is reported. """ + spec = db_spec(dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name != dbs: + if not spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, reason)) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason)) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -117,23 +134,25 @@ def fails_on_everything_except(*dbs): databases except those listed. """ + spec = db_spec(*dbs) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name in dbs: + if spec(config.db): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " - "'%s': %s" % ( - fn_name, config.db.name, str(ex))) + "'%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, str(ex))) return True else: raise AssertionError( - "Unexpected success for '%s' on DB implementation '%s'" % - (fn_name, config.db.name)) + "Unexpected success for '%s' on DB implementation '%s+%s'" % + (fn_name, config.db.name, config.db.driver)) return function_named(maybe, fn_name) return decorate @@ -145,12 +164,13 @@ def crashes(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -169,12 +189,13 @@ def _block_unconditionally(db, reason): """ carp = _should_carp_about_exclusion(reason) + spec = db_spec(db) def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): - if config.db.name == db: - msg = "'%s' unsupported on DB implementation '%s': %s" % ( - fn_name, config.db.name, reason) + if spec(config.db): + msg = "'%s' unsupported on DB implementation '%s+%s': %s" % ( + fn_name, config.db.name, config.db.driver, reason) print msg if carp: print >> sys.stderr, msg @@ -198,6 +219,7 @@ def exclude(db, op, spec, reason): """ carp = _should_carp_about_exclusion(reason) + def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): @@ -242,7 +264,9 @@ def _is_excluded(db, op, spec): _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ - if config.db.name != db: + vendor_spec = db_spec(db) + + if not vendor_spec(config.db): return False version = _server_version() @@ -255,7 +279,12 @@ def _server_version(bind=None): if bind is None: bind = config.db - return bind.dialect.server_version_info(bind.contextual_connect()) + + # force metadata to be retrieved + conn = bind.connect() + version = getattr(bind.dialect, 'server_version_info', ()) + conn.close() + return version def skip_if(predicate, reason=None): """Skip a test if predicate is true.""" @@ -266,8 +295,7 @@ def skip_if(predicate, reason=None): if predicate(): msg = "'%s' skipped on DB %s version '%s': %s" % ( fn_name, config.db.name, _server_version(), reason) - print msg - return True + raise SkipTest(msg) else: return fn(*args, **kw) return function_named(maybe, fn_name) @@ -315,10 +343,12 @@ def emits_warning_on(db, *warnings): strings; these will be matched to the root of the warning description by warnings.filterwarnings(). """ + spec = db_spec(db) + def decorate(fn): def maybe(*args, **kw): if isinstance(db, basestring): - if config.db.name != db: + if not spec(config.db): return fn(*args, **kw) else: wrapped = emits_warning(*warnings)(fn) @@ -384,6 +414,19 @@ def resetwarnings(): if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning) +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs + + def against(*queries): """Boolean predicate, compares to testing database configuration. @@ -394,21 +437,20 @@ def against(*queries): Also supports comparison to database version when provided with one or more 3-tuples of dialect name, operator, and version specification:: - testing.against('mysql', 'postgres') + testing.against('mysql', 'postgresql') testing.against(('mysql', '>=', (5, 0, 0)) """ for query in queries: if isinstance(query, basestring): - if config.db.name == query: + if db_spec(query)(config.db): return True else: name, op, spec = query - if config.db.name != name: + if not db_spec(name)(config.db): continue - have = config.db.dialect.server_version_info( - config.db.contextual_connect()) + have = _server_version() oper = hasattr(op, '__call__') and op or _ops[op] if oper(have, spec): @@ -545,16 +587,15 @@ class AssertsCompiledSQL(object): if dialect is None: dialect = getattr(self, '__dialect__', None) - if params is None: - keys = None - else: - keys = params.keys() + kw = {} + if params is not None: + kw['column_keys'] = params.keys() - c = clause.compile(column_keys=keys, dialect=dialect) + c = clause.compile(dialect=dialect, **kw) - print "\nSQL String:\n" + str(c) + repr(c.params) + print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {})) - cc = re.sub(r'\n', '', str(c)) + cc = re.sub(r'[\n\t]', '', str(c)) eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) @@ -563,18 +604,13 @@ class AssertsCompiledSQL(object): class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): - base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] eq_(c.primary_key, reflected_c.primary_key) eq_(c.nullable, reflected_c.nullable) - assert len( - set(type(reflected_c.type).__mro__).difference(base_mro).intersection( - set(type(c.type).__mro__).difference(base_mro) - ) - ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) @@ -586,14 +622,21 @@ class ComparesTables(object): elif against(('mysql', '<', (5, 0))): # ignore reflection of bogus db-generated DefaultClause() pass - elif not c.primary_key or not against('postgres'): - print repr(c) + elif not c.primary_key or not against('postgresql', 'mssql'): + #print repr(c) assert reflected_c.default is None, reflected_c.default assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] - + + def assert_types_base(self, c1, c2): + base_mro = sqltypes.TypeEngine.__mro__ + assert len( + set(type(c1.type).__mro__).difference(base_mro).intersection( + set(type(c2.type).__mro__).difference(base_mro) + ) + ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type) class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): @@ -678,7 +721,7 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_sql(self, db, callable_, list_, with_sequences=None): - if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): + if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'): rules = with_sequences else: rules = list_ diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py new file mode 100644 index 000000000..60b0a4ef8 --- /dev/null +++ b/lib/sqlalchemy/test/util.py @@ -0,0 +1,24 @@ +from sqlalchemy.util import jython, function_named + +import gc +import time + +if jython: + def gc_collect(*args): + """aggressive gc.collect for tests.""" + gc.collect() + time.sleep(0.1) + gc.collect() + gc.collect() + return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = gc_collect + +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + def lazy_gc(): + pass + + |
