summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-09-26 17:21:21 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-09-26 17:21:21 -0400
commit1a7778632d8039bd687e23522ce6c027e859d487 (patch)
tree1ea58d81c25db7d65f6244d1e12344c1ef8dd97f
parent0895e34d21c9818ec73e5c87e35ad6ba5c05acbd (diff)
downloadsqlalchemy-1a7778632d8039bd687e23522ce6c027e859d487.tar.gz
- further reorganization of test suite:
- bootstrap and lib move to all absolute imports - testing.py is no longer internally referenced. - requirements move to be a pluggable class which can be overridden. - cleanup in the interests of third party testing, test/lib and test/bootstrap may move to be an independent package.
-rw-r--r--setup.cfg19
-rw-r--r--test/bootstrap/config.py59
-rw-r--r--test/bootstrap/noseplugin.py94
-rw-r--r--test/lib/__init__.py18
-rw-r--r--test/lib/assertions.py349
-rw-r--r--test/lib/engines.py11
-rw-r--r--test/lib/fixtures.py21
-rw-r--r--test/lib/orm.py111
-rw-r--r--test/lib/pickleable.py20
-rw-r--r--test/lib/profiling.py14
-rw-r--r--test/lib/requires.py899
-rw-r--r--test/lib/schema.py22
-rw-r--r--test/lib/testing.py452
-rw-r--r--test/lib/util.py81
-rw-r--r--test/lib/warnings.py43
-rw-r--r--test/sql/test_defaults.py3
16 files changed, 1090 insertions, 1126 deletions
diff --git a/setup.cfg b/setup.cfg
index c2983fe41..1610fb00f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,7 +2,24 @@
tag_build = dev
[nosetests]
-with-_sqlalchemy = true
+with-sqlalchemy = true
exclude = ^examples
first-package-wins = true
where = test
+
+[requirements]
+requirement_cls=test.lib.requires:DefaultRequirements
+
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+postgres=postgresql://scott:tiger@127.0.0.1:5432/test
+pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
+postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?use_unicode=0&charset=utf8
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+maxdb=maxdb://MONA:RED@/maxdb1
diff --git a/test/bootstrap/config.py b/test/bootstrap/config.py
index 86dc78bcc..dfaf972b9 100644
--- a/test/bootstrap/config.py
+++ b/test/bootstrap/config.py
@@ -1,36 +1,19 @@
-import optparse, os, sys, re, ConfigParser, time, warnings
+"""Option and configuration implementations, run by the nose plugin
+on test suite startup."""
-# 2to3
-import StringIO
+import time
+import warnings
+import sys
+import re
logging = None
-
-__all__ = 'parser', 'configure', 'options',
-
db = None
-db_label, db_url, db_opts = None, None, {}
-
+db_label = None
+db_url = None
+db_opts = {}
options = None
file_config = None
-base_config = """
-[db]
-sqlite=sqlite:///:memory:
-sqlite_file=sqlite:///querytest.db
-postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
-postgres=postgresql://scott:tiger@127.0.0.1:5432/test
-pg8000=postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
-postgresql_jython=postgresql+zxjdbc://scott:tiger@127.0.0.1:5432/test
-mysql_jython=mysql+zxjdbc://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test
-pymysql=mysql+pymysql://scott:tiger@127.0.0.1:3306/test?use_unicode=0&charset=utf8
-oracle=oracle://scott:tiger@127.0.0.1:1521
-oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
-mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
-firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
-maxdb=maxdb://MONA:RED@/maxdb1
-"""
-
def _log(option, opt_str, value, parser):
global logging
if not logging:
@@ -49,11 +32,9 @@ def _list_dbs(*args):
print "%20s\t%s" % (macro, file_config.get('db', macro))
sys.exit(0)
-
def _server_side_cursors(options, opt_str, value, parser):
db_opts['server_side_cursors'] = True
-
def _zero_timeout(options, opt_str, value, parser):
warnings.warn("--zero-timeout testing option is now on in all cases")
@@ -171,6 +152,26 @@ def _reverse_topological(options, file_config):
from sqlalchemy.orm import unitofwork, session, mapper, dependency
from sqlalchemy.util import topological
from test.lib.util import RandomSet
- topological.set = unitofwork.set = session.set = mapper.set = dependency.set = RandomSet
+ topological.set = unitofwork.set = session.set = mapper.set = \
+ dependency.set = RandomSet
post_configure.append(_reverse_topological)
+def _requirements(options, file_config):
+ from test.lib import testing
+ requirement_cls = file_config.get('requirements', "requirement_cls")
+
+ modname, clsname = requirement_cls.split(":")
+
+ # importlib.import_module() only introduced in 2.7, a little
+ # late
+ mod = __import__(modname)
+ for component in modname.split(".")[1:]:
+ mod = getattr(mod, component)
+
+ req_cls = getattr(mod, clsname)
+ global requirements
+ requirements = req_cls(db, sys.modules[__name__])
+ testing.requires = requirements
+
+post_configure.append(_requirements)
+
diff --git a/test/bootstrap/noseplugin.py b/test/bootstrap/noseplugin.py
index e664552dc..4aec808f9 100644
--- a/test/bootstrap/noseplugin.py
+++ b/test/bootstrap/noseplugin.py
@@ -1,29 +1,25 @@
-import logging
import os
-import re
-import sys
-import time
-import warnings
import ConfigParser
-import StringIO
-import nose.case
from nose.plugins import Plugin
from nose import SkipTest
from test.bootstrap import config
-from test.bootstrap.config import (
- _create_testing_engine, _engine_pool, _engine_strategy, _engine_uri, _list_dbs, _log,
- _prep_testing_database, _require, _reverse_topological, _server_side_cursors,
- _monkeypatch_cdecimal, _zero_timeout,
- _set_table_options, base_config, db, db_label, db_url, file_config, post_configure,
- pre_configure)
+from test.bootstrap.config import _log, _list_dbs, _zero_timeout, \
+ _engine_strategy, _server_side_cursors, pre_configure,\
+ post_configure
+# late imports
testing = None
+fixtures = None
engines = None
+exclusions = None
+warnings = None
+profiling = None
+assertions = None
+requires = None
util = None
-
-log = logging.getLogger('nose.plugins.sqlalchemy')
+file_config = None
class NoseSQLAlchemy(Plugin):
"""
@@ -33,7 +29,7 @@ class NoseSQLAlchemy(Plugin):
# nose 1.0 will allow us to replace the old "sqlalchemy" plugin,
# if installed, using the same name, but nose 1.0 isn't released yet...
- name = '_sqlalchemy'
+ name = 'sqlalchemy'
score = 100
def options(self, parser, env=os.environ):
@@ -52,8 +48,7 @@ class NoseSQLAlchemy(Plugin):
opt("--dburi", action="store", dest="dburi",
help="Database uri (overrides --db)")
opt("--dropfirst", action="store_true", dest="dropfirst",
- help="Drop all tables in the target database first (use with caution on Oracle, "
- "MS-SQL)")
+ help="Drop all tables in the target database first")
opt("--mockpool", action="store_true", dest="mockpool",
help="Use mock pool (asserts only one connection used)")
opt("--zero-timeout", action="callback", callback=_zero_timeout,
@@ -86,8 +81,7 @@ class NoseSQLAlchemy(Plugin):
help="Write/update profiling data.")
global file_config
file_config = ConfigParser.ConfigParser()
- file_config.readfp(StringIO.StringIO(base_config))
- file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+ file_config.read(['setup.cfg', 'test.cfg', os.path.expanduser('~/.satest.cfg')])
config.file_config = file_config
def configure(self, options, conf):
@@ -97,17 +91,18 @@ class NoseSQLAlchemy(Plugin):
fn(self.options, file_config)
def begin(self):
- global testing, requires, util, fixtures, engines
- from test.lib import testing, requires, fixtures, engines
- from sqlalchemy import util
-
- testing.db = db
- testing.requires = requires
-
# Lazy setup of other options (post coverage)
for fn in post_configure:
fn(self.options, file_config)
+ # late imports, has to happen after config as well
+ # as nose plugins like coverage
+ global testing, requires, util, fixtures, engines, exclusions, \
+ assertions, warnings, profiling
+ from test.lib import testing, requires, fixtures, engines, exclusions, \
+ assertions, warnings, profiling
+ from sqlalchemy import util
+
def describeTest(self, test):
return ""
@@ -136,26 +131,27 @@ class NoseSQLAlchemy(Plugin):
def _do_skips(self, cls):
if hasattr(cls, '__requires__'):
- def test_suite(): return 'ok'
+ def test_suite():
+ return 'ok'
test_suite.__name__ = cls.__name__
for requirement in cls.__requires__:
- check = getattr(requires, requirement)
+ check = getattr(config.requirements, requirement)
check(test_suite)()
if cls.__unsupported_on__:
- spec = testing.db_spec(*cls.__unsupported_on__)
- if spec(testing.db):
+ spec = exclusions.db_spec(*cls.__unsupported_on__)
+ if spec(config.db):
raise SkipTest(
"'%s' unsupported on DB implementation '%s'" % (
- cls.__name__, testing.db.name)
+ cls.__name__, config.db.name)
)
if getattr(cls, '__only_on__', None):
- spec = testing.db_spec(*util.to_list(cls.__only_on__))
- if not spec(testing.db):
+ spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+ if not spec(config.db):
raise SkipTest(
"'%s' unsupported on DB implementation '%s'" % (
- cls.__name__, testing.db.name)
+ cls.__name__, config.db.name)
)
if getattr(cls, '__skip_if__', False):
@@ -166,43 +162,29 @@ class NoseSQLAlchemy(Plugin):
)
for db, op, spec in getattr(cls, '__excluded_on__', ()):
- testing.exclude(db, op, spec,
+ exclusions.exclude(db, op, spec,
"'%s' unsupported on DB %s version %s" % (
- cls.__name__, testing.db.name,
- testing._server_version()))
+ cls.__name__, config.db.name,
+ exclusions._server_version()))
def beforeTest(self, test):
- testing.resetwarnings()
- testing.current_test = test.id()
+ warnings.resetwarnings()
+ profiling._current_test = test.id()
def afterTest(self, test):
engines.testing_reaper._after_test_ctx()
- testing.resetwarnings()
-
- def _setup_cls_engines(self, cls):
- engine_opts = getattr(cls, '__testing_engine__', None)
- if engine_opts:
- self._save_testing_db = testing.db
- testing.db = engines.testing_engine(options=engine_opts)
-
- def _teardown_cls_engines(self, cls):
- engine_opts = getattr(cls, '__testing_engine__', None)
- if engine_opts:
- testing.db = self._save_testing_db
- del self._save_testing_db
+ warnings.resetwarnings()
def startContext(self, ctx):
if not isinstance(ctx, type) \
or not issubclass(ctx, fixtures.TestBase):
return
self._do_skips(ctx)
- self._setup_cls_engines(ctx)
def stopContext(self, ctx):
if not isinstance(ctx, type) \
or not issubclass(ctx, fixtures.TestBase):
return
engines.testing_reaper._stop_test_ctx()
- self._teardown_cls_engines(ctx)
if not config.options.low_connections:
- testing.global_cleanup_assertions()
+ assertions.global_cleanup_assertions()
diff --git a/test/lib/__init__.py b/test/lib/__init__.py
index b36db71fc..a1b8eb6d1 100644
--- a/test/lib/__init__.py
+++ b/test/lib/__init__.py
@@ -1,22 +1,20 @@
"""Testing environment and utilities.
-This package contains base classes and routines used by
-the unit tests. Tests are based on Nose and bootstrapped
-by noseplugin.NoseSQLAlchemy.
-
"""
-from test.bootstrap import config
-from test.lib import testing, engines, requires, profiling, pickleable, \
+from ..bootstrap import config
+from . import testing, engines, requires, profiling, pickleable, \
fixtures
-from test.lib.schema import Column, Table
-from test.lib.testing import AssertsCompiledSQL, \
- AssertsExecutionResults, ComparesTables, rowset
+from .schema import Column, Table
+from .assertions import AssertsCompiledSQL, \
+ AssertsExecutionResults, ComparesTables
+from .util import rowset
__all__ = ('testing',
'Column', 'Table',
- 'rowset','fixtures',
+ 'rowset',
+ 'fixtures',
'AssertsExecutionResults',
'AssertsCompiledSQL', 'ComparesTables',
'engines', 'profiling', 'pickleable')
diff --git a/test/lib/assertions.py b/test/lib/assertions.py
new file mode 100644
index 000000000..70284799c
--- /dev/null
+++ b/test/lib/assertions.py
@@ -0,0 +1,349 @@
+from __future__ import absolute_import
+
+from . import util as testutil
+from sqlalchemy import pool, orm, util
+from sqlalchemy.engine import default
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.util import decorator
+from sqlalchemy import types as sqltypes, schema
+import warnings
+import re
+from .warnings import resetwarnings
+from .exclusions import db_spec, _is_excluded
+from . import assertsql
+from ..bootstrap import config
+import itertools
+from .util import fail
+
+def emits_warning(*messages):
+ """Mark a test as emitting a warning.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ # TODO: it would be nice to assert that a named warning was
+ # emitted. should work with some monkeypatching of warnings,
+ # and may work on non-CPython if they keep to the spirit of
+ # warnings.showwarning's docstring.
+ # - update: jython looks ok, it uses cpython's module
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SAWarning))
+ else:
+ filters.extend(dict(action='ignore',
+ message=message,
+ category=sa_exc.SAWarning)
+ for message in messages)
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+def emits_warning_on(db, *warnings):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+ """
+ spec = db_spec(db)
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ if isinstance(db, basestring):
+ if not spec(config.db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ else:
+ if not _is_excluded(*db):
+ return fn(*args, **kw)
+ else:
+ wrapped = emits_warning(*warnings)(fn)
+ return wrapped(*args, **kw)
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ # todo: should probably be strict about this, too
+ filters = [dict(action='ignore',
+ category=sa_exc.SAPendingDeprecationWarning)]
+ if not messages:
+ filters.append(dict(action='ignore',
+ category=sa_exc.SADeprecationWarning))
+ else:
+ filters.extend(
+ [dict(action='ignore',
+ message=message,
+ category=sa_exc.SADeprecationWarning)
+ for message in
+ [(m.startswith('//') and
+ ('Call to deprecated function ' + m[2:]) or m)
+ for m in messages]])
+
+ for f in filters:
+ warnings.filterwarnings(**f)
+ try:
+ return fn(*args, **kw)
+ finally:
+ resetwarnings()
+ return decorate
+
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+
+ testutil.lazy_gc()
+ assert not pool._refs, str(pool._refs)
+
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a, fragment)
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ try:
+ callable_(*args, **kw)
+ success = False
+ except except_cls:
+ success = True
+
+ # assert outside the block so it works for AssertionError too !
+ assert success, "Callable did not raise an exception"
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ try:
+ callable_(*args, **kwargs)
+ assert False, "Callable did not raise an exception"
+ except except_cls, e:
+ assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
+ print unicode(e).encode('utf-8')
+
+
+class AssertsCompiledSQL(object):
+ def assert_compile(self, clause, result, params=None,
+ checkparams=None, dialect=None,
+ checkpositional=None,
+ use_default_dialect=False,
+ allow_dialect_select=False):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ elif dialect == None and not allow_dialect_select:
+ dialect = getattr(self, '__dialect__', None)
+ if dialect == 'default':
+ dialect = default.DefaultDialect()
+ elif dialect is None:
+ dialect = config.db.dialect
+
+ kw = {}
+ if params is not None:
+ kw['column_keys'] = params.keys()
+
+ if isinstance(clause, orm.Query):
+ context = clause._compile_context()
+ context.statement.use_labels = True
+ clause = context.statement
+
+ c = clause.compile(dialect=dialect, **kw)
+
+ param_str = repr(getattr(c, 'params', {}))
+ # Py3K
+ #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
+
+ print "\nSQL String:\n" + str(c) + param_str
+
+ cc = re.sub(r'[\n\t]', '', str(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ assert type(reflected_c.type) is type(c.type), \
+ "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ if c.server_default:
+ assert isinstance(reflected_c.server_default,
+ schema.FetchedValue)
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(c2.type),\
+ "On column %r, type '%s' doesn't correspond to type '%s'" % \
+ (c1.name, c1.type, c2.type)
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print repr(result)
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list):
+ self.assert_(len(result) == len(list),
+ "result list is not the same size as test list, " +
+ "for class " + class_.__name__)
+ for i in range(0, len(list)):
+ self.assert_row(class_, result[i], list[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(rowobj.__class__ is class_,
+ "item class is not " + repr(class_))
+ for key, value in desc.iteritems():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s" % (
+ key, getattr(rowobj, key), value))
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = set([immutabledict(e) for e in expected])
+
+ for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
+ fail('Unexpected type "%s", expected "%s"' % (
+ type(wrong).__name__, cls.__name__))
+
+ if len(found) != len(expected):
+ fail('Unexpected object count "%s", expected "%s"' % (
+ len(found), len(expected)))
+
+ NOVALUE = object()
+ def _compare_item(obj, spec):
+ for key, value in spec.iteritems():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1])
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found." % (
+ cls.__name__, repr(expected_item)))
+ return True
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ assertsql.asserter.add_rules(rules)
+ try:
+ callable_()
+ assertsql.asserter.statement_complete()
+ finally:
+ assertsql.asserter.clear_rules()
+
+ def assert_sql(self, db, callable_, list_, with_sequences=None):
+ if with_sequences is not None and config.db.dialect.supports_sequences:
+ rules = with_sequences
+ else:
+ rules = list_
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(*[
+ assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
+ ])
+ else:
+ newrule = assertsql.ExactSQL(*rule)
+ newrules.append(newrule)
+
+ self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+
+
diff --git a/test/lib/engines.py b/test/lib/engines.py
index e226d11bc..1431d5a86 100644
--- a/test/lib/engines.py
+++ b/test/lib/engines.py
@@ -1,10 +1,11 @@
-import sys, types, weakref
+from __future__ import absolute_import
+
+import types
+import weakref
from collections import deque
-from test.bootstrap import config
-from test.lib.util import decorator, gc_collect
-from sqlalchemy.util import callable
+from ..bootstrap import config
+from .util import decorator
from sqlalchemy import event, pool
-from sqlalchemy.engine import base as engine_base
import re
import warnings
diff --git a/test/lib/fixtures.py b/test/lib/fixtures.py
index a2f7bd1fb..722b64cfc 100644
--- a/test/lib/fixtures.py
+++ b/test/lib/fixtures.py
@@ -1,10 +1,10 @@
-from test.lib import testing
-from test.lib import schema
-from test.lib.testing import adict
-from test.lib.engines import drop_all_tables
+from ..bootstrap import config
+from . import assertions, schema
+from .util import adict
+from .engines import drop_all_tables
+from .entities import BasicEntity, ComparableEntity
import sys
import sqlalchemy as sa
-from test.lib.entities import BasicEntity, ComparableEntity
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
class TestBase(object):
@@ -26,11 +26,6 @@ class TestBase(object):
# skipped.
__skip_if__ = None
- # replace testing.db with a testing.engine()
- # for the duration of this suite, using the given
- # arguments
- __testing_engine__ = None
-
def assert_(self, val, msg=None):
assert val, msg
@@ -158,7 +153,7 @@ class TablesTest(TestBase):
@classmethod
def setup_bind(cls):
- return testing.db
+ return config.db
@classmethod
def dispose_bind(cls, bind):
@@ -217,7 +212,7 @@ class _ORMTest(object):
class ORMTest(_ORMTest, TestBase):
pass
-class MappedTest(_ORMTest, TablesTest, testing.AssertsExecutionResults):
+class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = 'once'
@@ -336,4 +331,4 @@ class DeclarativeMappedTest(MappedTest):
cls.DeclarativeBasic = _DeclBase
fn()
if cls.metadata.tables:
- cls.metadata.create_all(testing.db)
+ cls.metadata.create_all(config.db)
diff --git a/test/lib/orm.py b/test/lib/orm.py
deleted file mode 100644
index 7ec13c555..000000000
--- a/test/lib/orm.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import inspect, re
-import config, testing
-from sqlalchemy import orm
-
-__all__ = 'mapper',
-
-
-_whitespace = re.compile(r'^(\s+)')
-
-def _find_pragma(lines, current):
- m = _whitespace.match(lines[current])
- basis = m and m.group() or ''
-
- for line in reversed(lines[0:current]):
- if 'testlib.pragma' in line:
- return line
- m = _whitespace.match(line)
- indent = m and m.group() or ''
-
- # simplistic detection:
-
- # >> # testlib.pragma foo
- # >> center_line()
- if indent == basis:
- break
- # >> # testlib.pragma foo
- # >> if fleem:
- # >> center_line()
- if line.endswith(':'):
- break
- return None
-
-def _make_blocker(method_name, fallback):
- """Creates tripwired variant of a method, raising when called.
-
- To excempt an invocation from blockage, there are two options.
-
- 1) add a pragma in a comment::
-
- # testlib.pragma exempt:methodname
- offending_line()
-
- 2) add a magic cookie to the function's namespace::
- __sa_baremethodname_exempt__ = True
- ...
- offending_line()
- another_offending_lines()
-
- The second is useful for testing and development.
- """
-
- if method_name.startswith('__') and method_name.endswith('__'):
- frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
- else:
- frame_marker = '__sa_%s_exempt__' % method_name
- pragma_marker = 'exempt:' + method_name
-
- def method(self, *args, **kw):
- frame_r = None
- try:
- frame = inspect.stack()[1][0]
- frame_r = inspect.getframeinfo(frame, 9)
-
- module = frame.f_globals.get('__name__', '')
-
- type_ = type(self)
-
- pragma = _find_pragma(*frame_r[3:5])
-
- exempt = (
- (not module.startswith('sqlalchemy')) or
- (pragma and pragma_marker in pragma) or
- (frame_marker in frame.f_locals) or
- ('self' in frame.f_locals and
- getattr(frame.f_locals['self'], frame_marker, False)))
-
- if exempt:
- supermeth = getattr(super(type_, self), method_name, None)
- if (supermeth is None or
- getattr(supermeth, 'im_func', None) is method):
- return fallback(self, *args, **kw)
- else:
- return supermeth(*args, **kw)
- else:
- raise AssertionError(
- "%s.%s called in %s, line %s in %s" % (
- type_.__name__, method_name, module, frame_r[1], frame_r[2]))
- finally:
- del frame
- method.__name__ = method_name
- return method
-
-def mapper(type_, *args, **kw):
- forbidden = [
- ('__hash__', 'unhashable', lambda s: id(s)),
- ('__eq__', 'noncomparable', lambda s, o: s is o),
- ('__ne__', 'noncomparable', lambda s, o: s is not o),
- ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
- ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
- ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
- ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
- ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
- ('__nonzero__', 'truthless', lambda s: 1), ]
-
- if isinstance(type_, type) and type_.__bases__ == (object,):
- for method_name, option, fallback in forbidden:
- if (getattr(config.options, option, False) and
- method_name not in type_.__dict__):
- setattr(type_, method_name, _make_blocker(method_name, fallback))
-
- return orm.mapper(type_, *args, **kw)
diff --git a/test/lib/pickleable.py b/test/lib/pickleable.py
index 98d104f1d..f5b8b827c 100644
--- a/test/lib/pickleable.py
+++ b/test/lib/pickleable.py
@@ -1,6 +1,6 @@
"""Classes used in pickling tests, need to be at the module level for unpickling."""
-from test.lib import fixtures
+from . import fixtures
class User(fixtures.ComparableEntity):
pass
@@ -28,9 +28,9 @@ class Parent(fixtures.ComparableEntity):
pass
class Screen(object):
- def __init__(self, obj, parent=None):
- self.obj = obj
- self.parent = parent
+ def __init__(self, obj, parent=None):
+ self.obj = obj
+ self.parent = parent
class Foo(object):
def __init__(self, moredata):
@@ -39,7 +39,9 @@ class Foo(object):
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
- return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
+ return other.data == self.data and \
+ other.stuff == self.stuff and \
+ other.moredata == self.moredata
class Bar(object):
@@ -48,7 +50,9 @@ class Bar(object):
self.y = y
__hash__ = object.__hash__
def __eq__(self, other):
- return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
+ return other.__class__ is self.__class__ and \
+ other.x == self.x and \
+ other.y == self.y
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
@@ -57,7 +61,9 @@ class OldSchool:
self.x = x
self.y = y
def __eq__(self, other):
- return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
+ return other.__class__ is self.__class__ and \
+ other.x == self.x and \
+ other.y == self.y
class OldSchoolWithoutCompare:
def __init__(self, x, y):
diff --git a/test/lib/profiling.py b/test/lib/profiling.py
index e02c4ce46..dbf969586 100644
--- a/test/lib/profiling.py
+++ b/test/lib/profiling.py
@@ -7,8 +7,8 @@ in a more fine-grained way than nose's profiling plugin.
import os
import sys
-from test.lib.util import gc_collect, decorator
-from test.lib import testing
+from .util import gc_collect, decorator
+from ..bootstrap import config
from nose import SkipTest
import pstats
import time
@@ -20,8 +20,7 @@ except ImportError:
cProfile = None
from sqlalchemy.util.compat import jython, pypy, win32
-from test.lib.requires import _has_cextensions
-_has_cext = _has_cextensions()
+_current_test = None
def profiled(target=None, **target_opts):
"""Function profiling.
@@ -109,7 +108,7 @@ class ProfileStatsFile(object):
@util.memoized_property
def platform_key(self):
- dbapi_key = testing.db.name + "_" + testing.db.driver
+ dbapi_key = config.db.name + "_" + config.db.driver
# keep it at 2.7, 3.1, 3.2, etc. for now.
py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
@@ -122,15 +121,16 @@ class ProfileStatsFile(object):
platform_tokens.append("pypy")
if win32:
platform_tokens.append("win")
+ _has_cext = config.requirements._has_cextensions()
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
return "_".join(platform_tokens)
def has_stats(self):
- test_key = testing.current_test
+ test_key = _current_test
return test_key in self.data and self.platform_key in self.data[test_key]
def result(self, callcount):
- test_key = testing.current_test
+ test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
diff --git a/test/lib/requires.py b/test/lib/requires.py
index 923dfb80d..1fff3d7c2 100644
--- a/test/lib/requires.py
+++ b/test/lib/requires.py
@@ -5,7 +5,7 @@ target database.
"""
-from exclusions import \
+from .exclusions import \
skip, \
skip_if,\
only_if,\
@@ -13,7 +13,8 @@ from exclusions import \
fails_on,\
fails_on_everything_except,\
fails_if,\
- SpecPredicate
+ SpecPredicate,\
+ against
def no_support(db, reason):
return SpecPredicate(db, description=reason)
@@ -22,442 +23,484 @@ def exclude(db, op, spec, description=None):
return SpecPredicate(db, op, spec, description=description)
from sqlalchemy import util
-from test.lib import config
-import testing
+from ..bootstrap import config
import sys
crashes = skip
+def _chain_decorators_on(*decorators):
+ def decorate(fn):
+ for decorator in reversed(decorators):
+ fn = decorator(fn)
+ return fn
+ return decorate
-def _chain_decorators_on(fn, *decorators):
- for decorator in reversed(decorators):
- fn = decorator(fn)
- return fn
-
-def deferrable_or_no_constraints(fn):
- """Target database must support derferable constraints."""
-
- return skip_if([
- no_support('firebird', 'not supported by database'),
- no_support('mysql', 'not supported by database'),
- no_support('mssql', 'not supported by database'),
- ])(fn)
-
-def foreign_keys(fn):
- """Target database must support foreign keys."""
-
- return skip_if(
- no_support('sqlite', 'not supported by database')
- )(fn)
-
-
-def unbounded_varchar(fn):
- """Target database must support VARCHAR with no length"""
-
- return skip_if([
- "firebird", "oracle", "mysql"
- ], "not supported by database"
- )(fn)
-
-def boolean_col_expressions(fn):
- """Target database must support boolean expressions as columns"""
- return skip_if([
- no_support('firebird', 'not supported by database'),
- no_support('oracle', 'not supported by database'),
- no_support('mssql', 'not supported by database'),
- no_support('sybase', 'not supported by database'),
- no_support('maxdb', 'FIXME: verify not supported by database'),
- no_support('informix', 'not supported by database'),
- ])(fn)
-
-def standalone_binds(fn):
- """target database/driver supports bound parameters as column expressions
- without being in the context of a typed column.
-
- """
- return skip_if(["firebird", "mssql+mxodbc"],
- "not supported by driver")(fn)
-
-def identity(fn):
- """Target database must support GENERATED AS IDENTITY or a facsimile.
-
- Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
- column DDL feature that fills in a DB-generated identifier at INSERT-time
- without requiring pre-execution of a SEQUENCE or other artifact.
-
- """
- return skip_if(["firebird", "oracle", "postgresql", "sybase"],
- "not supported by database"
- )(fn)
-
-def reflectable_autoincrement(fn):
- """Target database must support tables that can automatically generate
- PKs assuming they were reflected.
-
- this is essentially all the DBs in "identity" plus Postgresql, which
- has SERIAL support. FB and Oracle (and sybase?) require the Sequence to
- be explicitly added, including if the table was reflected.
- """
- return skip_if(["firebird", "oracle", "sybase"],
- "not supported by database"
- )(fn)
-
-def binary_comparisons(fn):
- """target database/driver can allow BLOB/BINARY fields to be compared
- against a bound parameter value.
- """
- return skip_if(["oracle", "mssql"],
- "not supported by database/driver"
- )(fn)
-
-def independent_cursors(fn):
- """Target must support simultaneous, independent database cursors
- on a single connection."""
-
- return skip_if(["mssql+pyodbc", "mssql+mxodbc"], "no driver support")
-
-def independent_connections(fn):
- """Target must support simultaneous, independent database connections."""
-
- # This is also true of some configurations of UnixODBC and probably win32
- # ODBC as well.
- return skip_if([
- no_support("sqlite",
- "independent connections disabled "
- "when :memory: connections are used"),
- exclude("mssql", "<", (9, 0, 0),
- "SQL Server 2005+ is required for "
- "independent connections"
- )
- ]
- )(fn)
-
-def updateable_autoincrement_pks(fn):
- """Target must support UPDATE on autoincrement/integer primary key."""
-
- return skip_if(["mssql", "sybase"],
- "IDENTITY columns can't be updated")(fn)
-
-def isolation_level(fn):
- return _chain_decorators_on(
- fn,
- only_on(('postgresql', 'sqlite', 'mysql'),
- "DBAPI has no isolation level support"),
- fails_on('postgresql+pypostgresql',
- 'pypostgresql bombs on multiple isolation level calls')
- )
-
-def row_triggers(fn):
- """Target must support standard statement-running EACH ROW triggers."""
-
- return skip_if([
- # no access to same table
- no_support('mysql', 'requires SUPER priv'),
- exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
-
- # huh? TODO: implement triggers for PG tests, remove this
- no_support('postgresql',
- 'PG triggers need to be implemented for tests'),
- ])(fn)
-
-def correlated_outer_joins(fn):
- """Target must support an outer join to a subquery which
- correlates to the parent."""
-
- return skip_if("oracle", 'Raises "ORA-01799: a column may not be '
- 'outer-joined to a subquery"')(fn)
-
-def update_from(fn):
- """Target must support UPDATE..FROM syntax"""
-
- return only_on(['postgresql', 'mssql', 'mysql'],
- "Backend does not support UPDATE..FROM")(fn)
-
-
-def savepoints(fn):
- """Target database must support savepoints."""
-
- return skip_if([
- "access",
- "sqlite",
- "sybase",
- ("mysql", "<", (5, 0, 3)),
- ("informix", "<", (11, 55, "xC3"))
- ], "savepoints not supported")(fn)
-
-def denormalized_names(fn):
- """Target database must have 'denormalized', i.e.
- UPPERCASE as case insensitive names."""
-
- return skip_if(
- lambda: not testing.db.dialect.requires_name_normalize,
- "Backend does not require denormalized names."
- )(fn)
-
-def schemas(fn):
- """Target database must support external schemas, and have one
- named 'test_schema'."""
-
- return skip_if([
- "sqlte",
- "firebird"
- ], "no schema support")
-
-def sequences(fn):
- """Target database must support SEQUENCEs."""
-
- return only_if([
- "postgresql", "firebird", "oracle"
- ], "no SEQUENCE support")(fn)
-
-def update_nowait(fn):
- """Target database must support SELECT...FOR UPDATE NOWAIT"""
- return skip_if(["access", "firebird", "mssql", "mysql", "sqlite", "sybase"],
- "no FOR UPDATE NOWAIT support"
- )(fn)
-
-def subqueries(fn):
- """Target database must support subqueries."""
-
- return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support')(fn)
-
-def intersect(fn):
- """Target database must support INTERSECT or equivalent."""
-
- return fails_if([
- "firebird", "mysql", "sybase", "informix"
- ], 'no support for INTERSECT')(fn)
-
-def except_(fn):
- """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
- return fails_if([
- "firebird", "mysql", "sybase", "informix"
- ], 'no support for EXCEPT')(fn)
-
-def offset(fn):
- """Target database must support some method of adding OFFSET or
- equivalent to a result set."""
- return fails_if([
- "sybase"
- ], 'no support for OFFSET or equivalent')(fn)
-
-def window_functions(fn):
- return only_if([
- "postgresql", "mssql", "oracle"
- ], "Backend does not support window functions")(fn)
-
-def returning(fn):
- return only_if(["postgresql", "mssql", "oracle", "firebird"],
- "'returning' not supported by database"
- )(fn)
-
-def two_phase_transactions(fn):
- """Target database must support two-phase transactions."""
-
- return skip_if([
- no_support('access', 'two-phase xact not supported by database'),
- no_support('firebird', 'no SA implementation'),
- no_support('maxdb', 'two-phase xact not supported by database'),
- no_support('mssql', 'two-phase xact not supported by drivers'),
- no_support('oracle', 'two-phase xact not implemented in SQLA/oracle'),
- no_support('drizzle', 'two-phase xact not supported by database'),
- no_support('sqlite', 'two-phase xact not supported by database'),
- no_support('sybase', 'two-phase xact not supported by drivers/SQLA'),
- no_support('postgresql+zxjdbc',
- 'FIXME: JDBC driver confuses the transaction state, may '
- 'need separate XA implementation'),
- exclude('mysql', '<', (5, 0, 3),
- 'two-phase xact not supported by database'),
- ])(fn)
-
-def views(fn):
- """Target database must support VIEWs."""
-
- return skip_if("drizzle", "no VIEW support")(fn)
-
-def unicode_connections(fn):
- """Target driver must support some encoding of Unicode across the wire."""
- # TODO: expand to exclude MySQLdb versions w/ broken unicode
- return skip_if([
- exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
- ])(fn)
-
-def unicode_ddl(fn):
- """Target driver must support some encoding of Unicode across the wire."""
- # TODO: expand to exclude MySQLdb versions w/ broken unicode
- return skip_if([
- no_support('maxdb', 'database support flakey'),
- no_support('oracle', 'FIXME: no support in database?'),
- no_support('sybase', 'FIXME: guessing, needs confirmation'),
- no_support('mssql+pymssql', 'no FreeTDS support'),
- exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
- ])(fn)
-
-def sane_rowcount(fn):
- return skip_if(
- lambda: not testing.db.dialect.supports_sane_rowcount,
- "driver doesn't support 'sane' rowcount"
- )(fn)
-
-def cextensions(fn):
- return skip_if(
- lambda: not _has_cextensions(), "C extensions not installed"
- )(fn)
-
-
-def emulated_lastrowid(fn):
- """"target dialect retrieves cursor.lastrowid or an equivalent
- after an insert() construct executes.
- """
- return fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
- 'sqlite+pysqlite', 'mysql+pymysql',
- 'mssql+pyodbc', 'mssql+mxodbc')(fn)
-
-def dbapi_lastrowid(fn):
- """"target backend includes a 'lastrowid' accessor on the DBAPI
- cursor object.
-
- """
- return fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
- 'sqlite+pysqlite', 'mysql+pymysql')(fn)
-
-def sane_multi_rowcount(fn):
- return skip_if(
- lambda: not testing.db.dialect.supports_sane_multi_rowcount,
- "driver doesn't support 'sane' multi row count"
+class Requirements(object):
+ def __init__(self, db, config):
+ self.db = db
+ self.config = config
+
+
+class DefaultRequirements(Requirements):
+ @property
+ def deferrable_or_no_constraints(self):
+ """Target database must support derferable constraints."""
+
+ return skip_if([
+ no_support('firebird', 'not supported by database'),
+ no_support('mysql', 'not supported by database'),
+ no_support('mssql', 'not supported by database'),
+ ])
+
+ @property
+ def foreign_keys(self):
+ """Target database must support foreign keys."""
+
+ return skip_if(
+ no_support('sqlite', 'not supported by database')
+ )
+
+
+ @property
+ def unbounded_varchar(self):
+ """Target database must support VARCHAR with no length"""
+
+ return skip_if([
+ "firebird", "oracle", "mysql"
+ ], "not supported by database"
+ )
+
+ @property
+ def boolean_col_expressions(self):
+ """Target database must support boolean expressions as columns"""
+ return skip_if([
+ no_support('firebird', 'not supported by database'),
+ no_support('oracle', 'not supported by database'),
+ no_support('mssql', 'not supported by database'),
+ no_support('sybase', 'not supported by database'),
+ no_support('maxdb', 'FIXME: verify not supported by database'),
+ no_support('informix', 'not supported by database'),
+ ])
+
+ @property
+ def standalone_binds(self):
+ """target database/driver supports bound parameters as column expressions
+ without being in the context of a typed column.
+
+ """
+ return skip_if(["firebird", "mssql+mxodbc"],
+ "not supported by driver")
+
+ @property
+ def identity(self):
+ """Target database must support GENERATED AS IDENTITY or a facsimile.
+
+ Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
+ column DDL feature that fills in a DB-generated identifier at INSERT-time
+ without requiring pre-execution of a SEQUENCE or other artifact.
+
+ """
+ return skip_if(["firebird", "oracle", "postgresql", "sybase"],
+ "not supported by database"
+ )
+
+ @property
+ def reflectable_autoincrement(self):
+ """Target database must support tables that can automatically generate
+ PKs assuming they were reflected.
+
+ this is essentially all the DBs in "identity" plus Postgresql, which
+ has SERIAL support. FB and Oracle (and sybase?) require the Sequence to
+ be explicitly added, including if the table was reflected.
+ """
+ return skip_if(["firebird", "oracle", "sybase"],
+ "not supported by database"
+ )
+
+ @property
+ def binary_comparisons(self):
+ """target database/driver can allow BLOB/BINARY fields to be compared
+ against a bound parameter value.
+ """
+ return skip_if(["oracle", "mssql"],
+ "not supported by database/driver"
+ )
+
+ @property
+ def independent_cursors(self):
+ """Target must support simultaneous, independent database cursors
+ on a single connection."""
+
+ return skip_if(["mssql+pyodbc", "mssql+mxodbc"], "no driver support")
+
+ @property
+ def independent_connections(self):
+ """Target must support simultaneous, independent database connections."""
+
+ # This is also true of some configurations of UnixODBC and probably win32
+ # ODBC as well.
+ return skip_if([
+ no_support("sqlite",
+ "independent connections disabled "
+ "when :memory: connections are used"),
+ exclude("mssql", "<", (9, 0, 0),
+ "SQL Server 2005+ is required for "
+ "independent connections"
+ )
+ ]
+ )
+
+ @property
+ def updateable_autoincrement_pks(self):
+ """Target must support UPDATE on autoincrement/integer primary key."""
+
+ return skip_if(["mssql", "sybase"],
+ "IDENTITY columns can't be updated")
+
+ @property
+ def isolation_level(self):
+ return _chain_decorators_on(
+ only_on(('postgresql', 'sqlite', 'mysql'),
+ "DBAPI has no isolation level support"),
+ fails_on('postgresql+pypostgresql',
+ 'pypostgresql bombs on multiple isolation level calls')
+ )
+
+ @property
+ def row_triggers(self):
+ """Target must support standard statement-running EACH ROW triggers."""
+
+ return skip_if([
+ # no access to same table
+ no_support('mysql', 'requires SUPER priv'),
+ exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
+
+ # huh? TODO: implement triggers for PG tests, remove this
+ no_support('postgresql',
+ 'PG triggers need to be implemented for tests'),
+ ])
+
+ @property
+ def correlated_outer_joins(self):
+ """Target must support an outer join to a subquery which
+ correlates to the parent."""
+
+ return skip_if("oracle", 'Raises "ORA-01799: a column may not be '
+ 'outer-joined to a subquery"')
+
+ @property
+ def update_from(self):
+ """Target must support UPDATE..FROM syntax"""
+
+ return only_on(['postgresql', 'mssql', 'mysql'],
+ "Backend does not support UPDATE..FROM")
+
+
+ @property
+ def savepoints(self):
+ """Target database must support savepoints."""
+
+ return skip_if([
+ "access",
+ "sqlite",
+ "sybase",
+ ("mysql", "<", (5, 0, 3)),
+ ("informix", "<", (11, 55, "xC3"))
+ ], "savepoints not supported")
+
+ @property
+ def denormalized_names(self):
+ """Target database must have 'denormalized', i.e.
+ UPPERCASE as case insensitive names."""
+
+ return skip_if(
+ lambda: not self.db.dialect.requires_name_normalize,
+ "Backend does not require denormalized names."
+ )
+
+ @property
+ def schemas(self):
+ """Target database must support external schemas, and have one
+ named 'test_schema'."""
+
+ return skip_if([
+ "sqlite",
+ "firebird"
+ ], "no schema support")
+
+ @property
+ def sequences(self):
+ """Target database must support SEQUENCEs."""
+
+ return only_if([
+ "postgresql", "firebird", "oracle"
+ ], "no SEQUENCE support")
+
+ @property
+ def update_nowait(self):
+ """Target database must support SELECT...FOR UPDATE NOWAIT"""
+ return skip_if(["access", "firebird", "mssql", "mysql", "sqlite", "sybase"],
+ "no FOR UPDATE NOWAIT support"
+ )
+
+ @property
+ def subqueries(self):
+ """Target database must support subqueries."""
+
+ return skip_if(exclude('mysql', '<', (4, 1, 1)), 'no subquery support')
+
+ @property
+ def intersect(self):
+ """Target database must support INTERSECT or equivalent."""
+
+ return fails_if([
+ "firebird", "mysql", "sybase", "informix"
+ ], 'no support for INTERSECT')
+
+ @property
+ def except_(self):
+ """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
+ return fails_if([
+ "firebird", "mysql", "sybase", "informix"
+ ], 'no support for EXCEPT')
+
+ @property
+ def offset(self):
+ """Target database must support some method of adding OFFSET or
+ equivalent to a result set."""
+ return fails_if([
+ "sybase"
+ ], 'no support for OFFSET or equivalent')
+
+ @property
+ def window_functions(self):
+ return only_if([
+ "postgresql", "mssql", "oracle"
+ ], "Backend does not support window functions")
+
+ @property
+ def returning(self):
+ return only_if(["postgresql", "mssql", "oracle", "firebird"],
+ "'returning' not supported by database"
+ )
+
+ @property
+ def two_phase_transactions(self):
+ """Target database must support two-phase transactions."""
+
+ return skip_if([
+ no_support('access', 'two-phase xact not supported by database'),
+ no_support('firebird', 'no SA implementation'),
+ no_support('maxdb', 'two-phase xact not supported by database'),
+ no_support('mssql', 'two-phase xact not supported by drivers'),
+ no_support('oracle', 'two-phase xact not implemented in SQLA/oracle'),
+ no_support('drizzle', 'two-phase xact not supported by database'),
+ no_support('sqlite', 'two-phase xact not supported by database'),
+ no_support('sybase', 'two-phase xact not supported by drivers/SQLA'),
+ no_support('postgresql+zxjdbc',
+ 'FIXME: JDBC driver confuses the transaction state, may '
+ 'need separate XA implementation'),
+ exclude('mysql', '<', (5, 0, 3),
+ 'two-phase xact not supported by database'),
+ ])
+
+ @property
+ def views(self):
+ """Target database must support VIEWs."""
+
+ return skip_if("drizzle", "no VIEW support")
+
+ @property
+ def unicode_connections(self):
+ """Target driver must support some encoding of Unicode across the wire."""
+ # TODO: expand to exclude MySQLdb versions w/ broken unicode
+ return skip_if([
+ exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
+ ])
+
+ @property
+ def unicode_ddl(self):
+ """Target driver must support some encoding of Unicode across the wire."""
+ # TODO: expand to exclude MySQLdb versions w/ broken unicode
+ return skip_if([
+ no_support('maxdb', 'database support flakey'),
+ no_support('oracle', 'FIXME: no support in database?'),
+ no_support('sybase', 'FIXME: guessing, needs confirmation'),
+ no_support('mssql+pymssql', 'no FreeTDS support'),
+ exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
+ ])
+
+ @property
+ def sane_rowcount(self):
+ return skip_if(
+ lambda: not self.db.dialect.supports_sane_rowcount,
+ "driver doesn't support 'sane' rowcount"
+ )
+
+ @property
+ def cextensions(self):
+ return skip_if(
+ lambda: not self._has_cextensions(), "C extensions not installed"
+ )
+
+
+ @property
+ def emulated_lastrowid(self):
+ """"target dialect retrieves cursor.lastrowid or an equivalent
+ after an insert() construct executes.
+ """
+ return fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
+ 'sqlite+pysqlite', 'mysql+pymysql',
+ 'mssql+pyodbc', 'mssql+mxodbc')
+
+ @property
+ def dbapi_lastrowid(self):
+ """"target backend includes a 'lastrowid' accessor on the DBAPI
+ cursor object.
+
+ """
+ return fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
+ 'sqlite+pysqlite', 'mysql+pymysql')
+
+ @property
+ def sane_multi_rowcount(self):
+ return skip_if(
+ lambda: not self.db.dialect.supports_sane_multi_rowcount,
+ "driver doesn't support 'sane' multi row count"
+ )
+
+ @property
+ def nullsordering(self):
+ """Target backends that support nulls ordering."""
+ return _chain_decorators_on(
+ fails_on_everything_except('postgresql', 'oracle', 'firebird')
+ )
+
+ @property
+ def reflects_pk_names(self):
+ """Target driver reflects the name of primary key constraints."""
+ return _chain_decorators_on(
+ fails_on_everything_except('postgresql', 'oracle')
+ )
+
+ @property
+ def python2(self):
+ return _chain_decorators_on(
+ skip_if(
+ lambda: sys.version_info >= (3,),
+ "Python version 2.xx is required."
+ )
+ )
+
+ @property
+ def python3(self):
+ return _chain_decorators_on(
+ skip_if(
+ lambda: sys.version_info < (3,),
+ "Python version 3.xx is required."
+ )
+ )
+
+ @property
+ def python26(self):
+ return _chain_decorators_on(
+ skip_if(
+ lambda: sys.version_info < (2, 6),
+ "Python version 2.6 or greater is required"
)
+ )
-def nullsordering(fn):
- """Target backends that support nulls ordering."""
- return _chain_decorators_on(
- fn,
- fails_on_everything_except('postgresql', 'oracle', 'firebird')
- )
-
-def reflects_pk_names(fn):
- """Target driver reflects the name of primary key constraints."""
- return _chain_decorators_on(
- fn,
- fails_on_everything_except('postgresql', 'oracle')
- )
-
-def python2(fn):
- return _chain_decorators_on(
- fn,
- skip_if(
- lambda: sys.version_info >= (3,),
- "Python version 2.xx is required."
+ @property
+ def python25(self):
+ return _chain_decorators_on(
+ skip_if(
+ lambda: sys.version_info < (2, 5),
+ "Python version 2.5 or greater is required"
)
- )
-
-def python3(fn):
- return _chain_decorators_on(
- fn,
- skip_if(
- lambda: sys.version_info < (3,),
- "Python version 3.xx is required."
+ )
+
+ @property
+ def cpython(self):
+ return _chain_decorators_on(
+ only_if(lambda: util.cpython,
+ "cPython interpreter needed"
+ )
+ )
+
+ @property
+ def predictable_gc(self):
+ """target platform must remove all cycles unconditionally when
+ gc.collect() is called, as well as clean out unreferenced subclasses.
+
+ """
+ return self.cpython
+
+ @property
+ def sqlite(self):
+ return _chain_decorators_on(
+ skip_if(lambda: not self._has_sqlite())
+ )
+
+ @property
+ def ad_hoc_engines(self):
+ """Test environment must allow ad-hoc engine/connection creation.
+
+ DBs that scale poorly for many connections, even when closed, i.e.
+ Oracle, may use the "--low-connections" option which flags this requirement
+ as not present.
+
+ """
+ return _chain_decorators_on(
+ skip_if(lambda: config.options.low_connections)
+ )
+
+ @property
+ def skip_mysql_on_windows(self):
+ """Catchall for a large variety of MySQL on Windows failures"""
+
+ return _chain_decorators_on(
+ skip_if(self._has_mysql_on_windows,
+ "Not supported on MySQL + Windows"
)
- )
-
-def python26(fn):
- return _chain_decorators_on(
- fn,
- skip_if(
- lambda: sys.version_info < (2, 6),
- "Python version 2.6 or greater is required"
)
- )
-
-def python25(fn):
- return _chain_decorators_on(
- fn,
- skip_if(
- lambda: sys.version_info < (2, 5),
- "Python version 2.5 or greater is required"
+
+ @property
+ def english_locale_on_postgresql(self):
+ return _chain_decorators_on(
+ skip_if(lambda: against('postgresql') \
+ and not self.db.scalar('SHOW LC_COLLATE').startswith('en'))
)
- )
-
-def cpython(fn):
- return _chain_decorators_on(
- fn,
- only_if(lambda: util.cpython,
- "cPython interpreter needed"
- )
- )
-
-def predictable_gc(fn):
- """target platform must remove all cycles unconditionally when
- gc.collect() is called, as well as clean out unreferenced subclasses.
-
- """
- return cpython(fn)
-
-def sqlite(fn):
- return _chain_decorators_on(
- fn,
- skip_if(lambda: not _has_sqlite())
- )
-
-def ad_hoc_engines(fn):
- """Test environment must allow ad-hoc engine/connection creation.
-
- DBs that scale poorly for many connections, even when closed, i.e.
- Oracle, may use the "--low-connections" option which flags this requirement
- as not present.
-
- """
- return _chain_decorators_on(
- fn,
- skip_if(lambda: config.options.low_connections)
- )
-
-def skip_mysql_on_windows(fn):
- """Catchall for a large variety of MySQL on Windows failures"""
-
- return _chain_decorators_on(
- fn,
- skip_if(_has_mysql_on_windows,
- "Not supported on MySQL + Windows"
+
+ @property
+ def selectone(self):
+ """target driver must support the literal statement 'select 1'"""
+ return _chain_decorators_on(
+ skip_if(lambda: against('oracle'),
+ "non-standard SELECT scalar syntax"),
+ skip_if(lambda: against('firebird'),
+ "non-standard SELECT scalar syntax")
)
- )
-
-def english_locale_on_postgresql(fn):
- return _chain_decorators_on(
- fn,
- skip_if(lambda: testing.against('postgresql') \
- and not testing.db.scalar('SHOW LC_COLLATE').startswith('en'))
- )
-
-def selectone(fn):
- """target driver must support the literal statement 'select 1'"""
- return _chain_decorators_on(
- fn,
- skip_if(lambda: testing.against('oracle'),
- "non-standard SELECT scalar syntax"),
- skip_if(lambda: testing.against('firebird'),
- "non-standard SELECT scalar syntax")
- )
-
-def _has_cextensions():
- try:
- from sqlalchemy import cresultproxy, cprocessors
- return True
- except ImportError:
- return False
-
-def _has_sqlite():
- from sqlalchemy import create_engine
- try:
- e = create_engine('sqlite://')
- return True
- except ImportError:
- return False
-
-def _has_mysql_on_windows():
- return testing.against('mysql') and \
- testing.db.dialect._detect_casing(testing.db) == 1
-
-def _has_mysql_fully_case_sensitive():
- return testing.against('mysql') and \
- testing.db.dialect._detect_casing(testing.db) == 0
+
+ def _has_cextensions(self):
+ try:
+ from sqlalchemy import cresultproxy, cprocessors
+ return True
+ except ImportError:
+ return False
+
+ def _has_sqlite(self):
+ from sqlalchemy import create_engine
+ try:
+ create_engine('sqlite://')
+ return True
+ except ImportError:
+ return False
+
+ def _has_mysql_on_windows(self):
+ return against('mysql') and \
+ self.db.dialect._detect_casing(self.db) == 1
+
+ def _has_mysql_fully_case_sensitive(self):
+ return against('mysql') and \
+ self.db.dialect._detect_casing(self.db) == 0
diff --git a/test/lib/schema.py b/test/lib/schema.py
index 2328770d5..6c7a684e2 100644
--- a/test/lib/schema.py
+++ b/test/lib/schema.py
@@ -2,8 +2,9 @@
desired state for different backends.
"""
-from test.lib import testing
+from . import exclusions
from sqlalchemy import schema, event
+from ..bootstrap import config
__all__ = 'Table', 'Column',
@@ -12,12 +13,12 @@ table_options = {}
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
- test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+ test_opts = dict([(k, kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
kw.update(table_options)
- if testing.against('mysql'):
+ if exclusions.against('mysql'):
if 'mysql_engine' not in kw and 'mysql_type' not in kw:
if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
kw['mysql_engine'] = 'InnoDB'
@@ -26,9 +27,9 @@ def Table(*args, **kw):
# Apply some default cascading rules for self-referential foreign keys.
# MySQL InnoDB has some issues around seleting self-refs too.
- if testing.against('firebird'):
+ if exclusions.against('firebird'):
table_name = args[0]
- unpack = (testing.config.db.dialect.
+ unpack = (config.db.dialect.
identifier_preparer.unformat_identifiers)
# Only going after ForeignKeys in Columns. May need to
@@ -59,23 +60,26 @@ def Table(*args, **kw):
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
- test_opts = dict([(k,kw.pop(k)) for k in kw.keys()
+ test_opts = dict([(k, kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
col = schema.Column(*args, **kw)
if 'test_needs_autoincrement' in test_opts and \
kw.get('primary_key', False) and \
- testing.against('firebird', 'oracle'):
+ exclusions.against('firebird', 'oracle'):
def add_seq(c, tbl):
c._init_items(
- schema.Sequence(_truncate_name(testing.db.dialect, tbl.name + '_' + c.name + '_seq'), optional=True)
+ schema.Sequence(_truncate_name(
+ config.db.dialect, tbl.name + '_' + c.name + '_seq'),
+ optional=True)
)
event.listen(col, 'after_parent_attach', add_seq, propagate=True)
return col
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
- return name[0:max(dialect.max_identifier_length - 6, 0)] + "_" + hex(hash(name) % 64)[2:]
+ return name[0:max(dialect.max_identifier_length - 6, 0)] + \
+ "_" + hex(hash(name) % 64)[2:]
else:
return name
diff --git a/test/lib/testing.py b/test/lib/testing.py
index 1de446f91..f244dfecc 100644
--- a/test/lib/testing.py
+++ b/test/lib/testing.py
@@ -1,453 +1,25 @@
-"""TestCase and TestSuite artifacts and testing decorators."""
+from __future__ import absolute_import
-import itertools
-import re
-import sys
-import types
-import warnings
-from cStringIO import StringIO
+from .warnings import testing_warn, assert_warnings, resetwarnings
-from test.bootstrap import config
-from test.lib import assertsql, util as testutil
+from ..bootstrap import config
+from . import assertsql, util as testutil
from sqlalchemy.util import decorator
-from engines import drop_all_tables
-from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, \
- pool, orm
-from sqlalchemy.engine import default
-from exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
+from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
fails_on, fails_on_everything_except, skip, only_on, exclude, against,\
_server_version
+from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
+ eq_, ne_, is_, is_not_, startswith_, assert_raises, \
+ assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
+
+from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
+
crashes = skip
-# sugar ('testing.db'); set here by config() at runtime
+# various sugar installed by config.py
db = None
-
-# more sugar, installed by __init__
requires = None
-def emits_warning(*messages):
- """Mark a test as emitting a warning.
-
- With no arguments, squelches all SAWarning failures. Or pass one or more
- strings; these will be matched to the root of the warning description by
- warnings.filterwarnings().
- """
- # TODO: it would be nice to assert that a named warning was
- # emitted. should work with some monkeypatching of warnings,
- # and may work on non-CPython if they keep to the spirit of
- # warnings.showwarning's docstring.
- # - update: jython looks ok, it uses cpython's module
-
- @decorator
- def decorate(fn, *args, **kw):
- # todo: should probably be strict about this, too
- filters = [dict(action='ignore',
- category=sa_exc.SAPendingDeprecationWarning)]
- if not messages:
- filters.append(dict(action='ignore',
- category=sa_exc.SAWarning))
- else:
- filters.extend(dict(action='ignore',
- message=message,
- category=sa_exc.SAWarning)
- for message in messages)
- for f in filters:
- warnings.filterwarnings(**f)
- try:
- return fn(*args, **kw)
- finally:
- resetwarnings()
- return decorate
-
-def emits_warning_on(db, *warnings):
- """Mark a test as emitting a warning on a specific dialect.
-
- With no arguments, squelches all SAWarning failures. Or pass one or more
- strings; these will be matched to the root of the warning description by
- warnings.filterwarnings().
- """
- spec = db_spec(db)
-
- @decorator
- def decorate(fn, *args, **kw):
- if isinstance(db, basestring):
- if not spec(config.db):
- return fn(*args, **kw)
- else:
- wrapped = emits_warning(*warnings)(fn)
- return wrapped(*args, **kw)
- else:
- if not _is_excluded(*db):
- return fn(*args, **kw)
- else:
- wrapped = emits_warning(*warnings)(fn)
- return wrapped(*args, **kw)
- return decorate
-
-def assert_warnings(fn, warnings):
- """Assert that each of the given warnings are emitted by fn."""
-
- canary = []
- orig_warn = util.warn
- def capture_warnings(*args, **kw):
- orig_warn(*args, **kw)
- popwarn = warnings.pop(0)
- canary.append(popwarn)
- eq_(args[0], popwarn)
- util.warn = util.langhelpers.warn = capture_warnings
-
- result = emits_warning()(fn)()
- assert canary, "No warning was emitted"
- return result
-
-def uses_deprecated(*messages):
- """Mark a test as immune from fatal deprecation warnings.
-
- With no arguments, squelches all SADeprecationWarning failures.
- Or pass one or more strings; these will be matched to the root
- of the warning description by warnings.filterwarnings().
-
- As a special case, you may pass a function name prefixed with //
- and it will be re-written as needed to match the standard warning
- verbiage emitted by the sqlalchemy.util.deprecated decorator.
- """
-
- @decorator
- def decorate(fn, *args, **kw):
- # todo: should probably be strict about this, too
- filters = [dict(action='ignore',
- category=sa_exc.SAPendingDeprecationWarning)]
- if not messages:
- filters.append(dict(action='ignore',
- category=sa_exc.SADeprecationWarning))
- else:
- filters.extend(
- [dict(action='ignore',
- message=message,
- category=sa_exc.SADeprecationWarning)
- for message in
- [ (m.startswith('//') and
- ('Call to deprecated function ' + m[2:]) or m)
- for m in messages] ])
-
- for f in filters:
- warnings.filterwarnings(**f)
- try:
- return fn(*args, **kw)
- finally:
- resetwarnings()
- return decorate
-
-def testing_warn(msg, stacklevel=3):
- """Replaces sqlalchemy.util.warn during tests."""
-
- filename = "test.lib.testing"
- lineno = 1
- if isinstance(msg, basestring):
- warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
- else:
- warnings.warn_explicit(msg, filename, lineno)
-
-def resetwarnings():
- """Reset warning behavior to testing defaults."""
-
- util.warn = util.langhelpers.warn = testing_warn
-
- warnings.filterwarnings('ignore',
- category=sa_exc.SAPendingDeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exc.SAWarning)
-
-
-def global_cleanup_assertions():
- """Check things that have to be finalized at the end of a test suite.
-
- Hardcoded at the moment, a modular system can be built here
- to support things like PG prepared transactions, tables all
- dropped, etc.
-
- """
-
- testutil.lazy_gc()
- assert not pool._refs, str(pool._refs)
-
-
-def run_as_contextmanager(ctx, fn, *arg, **kw):
- """Run the given function under the given contextmanager,
- simulating the behavior of 'with' to support older
- Python versions.
-
- """
-
- obj = ctx.__enter__()
- try:
- result = fn(obj, *arg, **kw)
- ctx.__exit__(None, None, None)
- return result
- except:
- exc_info = sys.exc_info()
- raise_ = ctx.__exit__(*exc_info)
- if raise_ is None:
- raise
- else:
- return raise_
-
-def rowset(results):
- """Converts the results of sql execution into a plain set of column tuples.
-
- Useful for asserting the results of an unordered query.
- """
-
- return set([tuple(row) for row in results])
-
-
-def eq_(a, b, msg=None):
- """Assert a == b, with repr messaging on failure."""
- assert a == b, msg or "%r != %r" % (a, b)
-
-def ne_(a, b, msg=None):
- """Assert a != b, with repr messaging on failure."""
- assert a != b, msg or "%r == %r" % (a, b)
-
-def is_(a, b, msg=None):
- """Assert a is b, with repr messaging on failure."""
- assert a is b, msg or "%r is not %r" % (a, b)
-
-def is_not_(a, b, msg=None):
- """Assert a is not b, with repr messaging on failure."""
- assert a is not b, msg or "%r is %r" % (a, b)
-
-def startswith_(a, fragment, msg=None):
- """Assert a.startswith(fragment), with repr messaging on failure."""
- assert a.startswith(fragment), msg or "%r does not start with %r" % (
- a, fragment)
-
-def assert_raises(except_cls, callable_, *args, **kw):
- try:
- callable_(*args, **kw)
- success = False
- except except_cls, e:
- success = True
-
- # assert outside the block so it works for AssertionError too !
- assert success, "Callable did not raise an exception"
-
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
- try:
- callable_(*args, **kwargs)
- assert False, "Callable did not raise an exception"
- except except_cls, e:
- assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
- print unicode(e).encode('utf-8')
-
-def fail(msg):
- assert False, msg
-
-
-@decorator
-def provide_metadata(fn, *args, **kw):
- """Provide bound MetaData for a single test, dropping afterwards."""
-
- metadata = schema.MetaData(db)
- self = args[0]
- prev_meta = getattr(self, 'metadata', None)
- self.metadata = metadata
- try:
- return fn(*args, **kw)
- finally:
- metadata.drop_all()
- self.metadata = prev_meta
-
-class adict(dict):
- """Dict keys available as attributes. Shadows."""
- def __getattribute__(self, key):
- try:
- return self[key]
- except KeyError:
- return dict.__getattribute__(self, key)
-
- def get_all(self, *keys):
- return tuple([self[key] for key in keys])
-
-
-class AssertsCompiledSQL(object):
- def assert_compile(self, clause, result, params=None,
- checkparams=None, dialect=None,
- checkpositional=None,
- use_default_dialect=False,
- allow_dialect_select=False):
- if use_default_dialect:
- dialect = default.DefaultDialect()
- elif dialect == None and not allow_dialect_select:
- dialect = getattr(self, '__dialect__', None)
- if dialect == 'default':
- dialect = default.DefaultDialect()
- elif dialect is None:
- dialect = db.dialect
-
- kw = {}
- if params is not None:
- kw['column_keys'] = params.keys()
-
- if isinstance(clause, orm.Query):
- context = clause._compile_context()
- context.statement.use_labels = True
- clause = context.statement
-
- c = clause.compile(dialect=dialect, **kw)
-
- param_str = repr(getattr(c, 'params', {}))
- # Py3K
- #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
-
- print "\nSQL String:\n" + str(c) + param_str
-
- cc = re.sub(r'[\n\t]', '', str(c))
-
- eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
-
- if checkparams is not None:
- eq_(c.construct_params(params), checkparams)
- if checkpositional is not None:
- p = c.construct_params(params)
- eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
-
-class ComparesTables(object):
- def assert_tables_equal(self, table, reflected_table, strict_types=False):
- assert len(table.c) == len(reflected_table.c)
- for c, reflected_c in zip(table.c, reflected_table.c):
- eq_(c.name, reflected_c.name)
- assert reflected_c is reflected_table.c[c.name]
- eq_(c.primary_key, reflected_c.primary_key)
- eq_(c.nullable, reflected_c.nullable)
-
- if strict_types:
- assert type(reflected_c.type) is type(c.type), \
- "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
- else:
- self.assert_types_base(reflected_c, c)
-
- if isinstance(c.type, sqltypes.String):
- eq_(c.type.length, reflected_c.type.length)
-
- eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
- if c.server_default:
- assert isinstance(reflected_c.server_default,
- schema.FetchedValue)
-
- assert len(table.primary_key) == len(reflected_table.primary_key)
- for c in table.primary_key:
- assert reflected_table.primary_key.columns[c.name] is not None
-
- def assert_types_base(self, c1, c2):
- assert c1.type._compare_type_affinity(c2.type),\
- "On column %r, type '%s' doesn't correspond to type '%s'" % \
- (c1.name, c1.type, c2.type)
-
-class AssertsExecutionResults(object):
- def assert_result(self, result, class_, *objects):
- result = list(result)
- print repr(result)
- self.assert_list(result, class_, objects)
-
- def assert_list(self, result, class_, list):
- self.assert_(len(result) == len(list),
- "result list is not the same size as test list, " +
- "for class " + class_.__name__)
- for i in range(0, len(list)):
- self.assert_row(class_, result[i], list[i])
-
- def assert_row(self, class_, rowobj, desc):
- self.assert_(rowobj.__class__ is class_,
- "item class is not " + repr(class_))
- for key, value in desc.iteritems():
- if isinstance(value, tuple):
- if isinstance(value[1], list):
- self.assert_list(getattr(rowobj, key), value[0], value[1])
- else:
- self.assert_row(value[0], getattr(rowobj, key), value[1])
- else:
- self.assert_(getattr(rowobj, key) == value,
- "attribute %s value %s does not match %s" % (
- key, getattr(rowobj, key), value))
-
- def assert_unordered_result(self, result, cls, *expected):
- """As assert_result, but the order of objects is not considered.
-
- The algorithm is very expensive but not a big deal for the small
- numbers of rows that the test suite manipulates.
- """
-
- class immutabledict(dict):
- def __hash__(self):
- return id(self)
-
- found = util.IdentitySet(result)
- expected = set([immutabledict(e) for e in expected])
-
- for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
- fail('Unexpected type "%s", expected "%s"' % (
- type(wrong).__name__, cls.__name__))
-
- if len(found) != len(expected):
- fail('Unexpected object count "%s", expected "%s"' % (
- len(found), len(expected)))
-
- NOVALUE = object()
- def _compare_item(obj, spec):
- for key, value in spec.iteritems():
- if isinstance(value, tuple):
- try:
- self.assert_unordered_result(
- getattr(obj, key), value[0], *value[1])
- except AssertionError:
- return False
- else:
- if getattr(obj, key, NOVALUE) != value:
- return False
- return True
-
- for expected_item in expected:
- for found_item in found:
- if _compare_item(found_item, expected_item):
- found.remove(found_item)
- break
- else:
- fail(
- "Expected %s instance with attributes %s not found." % (
- cls.__name__, repr(expected_item)))
- return True
-
- def assert_sql_execution(self, db, callable_, *rules):
- assertsql.asserter.add_rules(rules)
- try:
- callable_()
- assertsql.asserter.statement_complete()
- finally:
- assertsql.asserter.clear_rules()
-
- def assert_sql(self, db, callable_, list_, with_sequences=None):
- if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgresql'):
- rules = with_sequences
- else:
- rules = list_
-
- newrules = []
- for rule in rules:
- if isinstance(rule, dict):
- newrule = assertsql.AllOf(*[
- assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
- ])
- else:
- newrule = assertsql.ExactSQL(*rule)
- newrules.append(newrule)
-
- self.assert_sql_execution(db, callable_, *newrules)
-
- def assert_sql_count(self, db, callable_, count):
- self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
-
-
diff --git a/test/lib/util.py b/test/lib/util.py
index e8a10c65a..c2144f6bd 100644
--- a/test/lib/util.py
+++ b/test/lib/util.py
@@ -4,9 +4,11 @@ from sqlalchemy.util.compat import decimal
import gc
import time
import random
+import sys
+import types
if jython:
- def gc_collect(*args):
+ def jython_gc_collect(*args):
"""aggressive gc.collect for tests."""
gc.collect()
time.sleep(0.1)
@@ -15,12 +17,12 @@ if jython:
return 0
# "lazy" gc, for VM's that don't GC on refcount == 0
- lazy_gc = gc_collect
+ lazy_gc = jython_gc_collect
elif pypy:
- def gc_collect(*args):
+ def pypy_gc_collect(*args):
gc.collect()
gc.collect()
- lazy_gc = gc_collect
+ lazy_gc = pypy_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
gc_collect = gc.collect
@@ -40,9 +42,9 @@ def picklers():
picklers.add(pickle)
# yes, this thing needs this much testing
- for pickle in picklers:
+ for pickle_ in picklers:
for protocol in -1, 0, 1, 2:
- yield pickle.loads, lambda d:pickle.dumps(d, protocol)
+ yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
def round_decimal(value, prec):
@@ -50,7 +52,8 @@ def round_decimal(value, prec):
return round(value, prec)
# can also use shift() here but that is 2.6 only
- return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
+ return (value * decimal.Decimal("1" + "0" * prec)
+ ).to_integral(decimal.ROUND_FLOOR) / \
pow(10, prec)
class RandomSet(set):
@@ -127,3 +130,67 @@ def function_named(fn, name):
fn.func_defaults, fn.func_closure)
return fn
+
+
+def run_as_contextmanager(ctx, fn, *arg, **kw):
+ """Run the given function under the given contextmanager,
+ simulating the behavior of 'with' to support older
+ Python versions.
+
+ """
+
+ obj = ctx.__enter__()
+ try:
+ result = fn(obj, *arg, **kw)
+ ctx.__exit__(None, None, None)
+ return result
+ except:
+ exc_info = sys.exc_info()
+ raise_ = ctx.__exit__(*exc_info)
+ if raise_ is None:
+ raise
+ else:
+ return raise_
+
+def rowset(results):
+ """Converts the results of sql execution into a plain set of column tuples.
+
+ Useful for asserting the results of an unordered query.
+ """
+
+ return set([tuple(row) for row in results])
+
+
+def fail(msg):
+ assert False, msg
+
+
+@decorator
+def provide_metadata(fn, *args, **kw):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ..bootstrap.config import db
+ from sqlalchemy import schema
+
+ metadata = schema.MetaData(db)
+ self = args[0]
+ prev_meta = getattr(self, 'metadata', None)
+ self.metadata = metadata
+ try:
+ return fn(*args, **kw)
+ finally:
+ metadata.drop_all()
+ self.metadata = prev_meta
+
+class adict(dict):
+ """Dict keys available as attributes. Shadows."""
+ def __getattribute__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ return dict.__getattribute__(self, key)
+
+ def get_all(self, *keys):
+ return tuple([self[key] for key in keys])
+
+
diff --git a/test/lib/warnings.py b/test/lib/warnings.py
new file mode 100644
index 000000000..d17d9f465
--- /dev/null
+++ b/test/lib/warnings.py
@@ -0,0 +1,43 @@
+from __future__ import absolute_import
+
+import warnings
+from sqlalchemy import exc as sa_exc
+from sqlalchemy import util
+
+def testing_warn(msg, stacklevel=3):
+ """Replaces sqlalchemy.util.warn during tests."""
+
+ filename = "test.lib.testing"
+ lineno = 1
+ if isinstance(msg, basestring):
+ warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
+ else:
+ warnings.warn_explicit(msg, filename, lineno)
+
+def resetwarnings():
+ """Reset warning behavior to testing defaults."""
+
+ util.warn = util.langhelpers.warn = testing_warn
+
+ warnings.filterwarnings('ignore',
+ category=sa_exc.SAPendingDeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
+
+def assert_warnings(fn, warnings):
+ """Assert that each of the given warnings are emitted by fn."""
+
+ from .assertions import eq_, emits_warning
+
+ canary = []
+ orig_warn = util.warn
+ def capture_warnings(*args, **kw):
+ orig_warn(*args, **kw)
+ popwarn = warnings.pop(0)
+ canary.append(popwarn)
+ eq_(args[0], popwarn)
+ util.warn = util.langhelpers.warn = capture_warnings
+
+ result = emits_warning()(fn)()
+ assert canary, "No warning was emitted"
+ return result
diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py
index b5277c623..55aa86633 100644
--- a/test/sql/test_defaults.py
+++ b/test/sql/test_defaults.py
@@ -13,7 +13,6 @@ from sqlalchemy.dialects import sqlite
from test.lib import fixtures
class DefaultTest(fixtures.TestBase):
- __testing_engine__ = {'execution_options':{'native_odbc_execute':False}}
@classmethod
def setup_class(cls):
@@ -405,7 +404,6 @@ class DefaultTest(fixtures.TestBase):
class PKDefaultTest(fixtures.TablesTest):
__requires__ = ('subqueries',)
- __testing_engine__ = {'execution_options':{'native_odbc_execute':False}}
@classmethod
def define_tables(cls, metadata):
@@ -441,7 +439,6 @@ class PKDefaultTest(fixtures.TablesTest):
class PKIncrementTest(fixtures.TablesTest):
run_define_tables = 'each'
- __testing_engine__ = {'execution_options':{'native_odbc_execute':False}}
@classmethod
def define_tables(cls, metadata):