summaryrefslogtreecommitdiff
path: root/test/testlib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-05-09 16:34:10 +0000
commit4a6afd469fad170868554bf28578849bf3dfd5dd (patch)
treeb396edc33d567ae19dd244e87137296450467725 /test/testlib
parent46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff)
downloadsqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'test/testlib')
-rw-r--r--test/testlib/__init__.py22
-rw-r--r--test/testlib/compat.py18
-rw-r--r--test/testlib/engines.py2
-rw-r--r--test/testlib/filters.py4
-rw-r--r--test/testlib/fixtures.py72
-rw-r--r--test/testlib/profiling.py5
-rw-r--r--test/testlib/requires.py32
-rw-r--r--test/testlib/schema.py2
-rw-r--r--test/testlib/tables.py5
-rw-r--r--test/testlib/testing.py134
10 files changed, 221 insertions, 75 deletions
diff --git a/test/testlib/__init__.py b/test/testlib/__init__.py
index 98552b0f3..67e56e3d8 100644
--- a/test/testlib/__init__.py
+++ b/test/testlib/__init__.py
@@ -3,14 +3,21 @@
Load after sqlalchemy imports to use instrumented stand-ins like Table.
"""
+import sys
import testlib.config
from testlib.schema import Table, Column
from testlib.orm import mapper
import testlib.testing as testing
-from testlib.testing import rowset
-from testlib.testing import TestBase, AssertsExecutionResults, ORMTest, AssertsCompiledSQL, ComparesTables
+from testlib.testing import \
+ AssertsCompiledSQL, \
+ AssertsExecutionResults, \
+ ComparesTables, \
+ ORMTest, \
+ TestBase, \
+ rowset
import testlib.profiling as profiling
import testlib.engines as engines
+import testlib.requires as requires
from testlib.compat import set, frozenset, sorted, _function_named
@@ -18,6 +25,15 @@ __all__ = ('testing',
'mapper',
'Table', 'Column',
'rowset',
- 'TestBase', 'AssertsExecutionResults', 'ORMTest', 'AssertsCompiledSQL', 'ComparesTables',
+ 'TestBase', 'AssertsExecutionResults', 'ORMTest',
+ 'AssertsCompiledSQL', 'ComparesTables',
'profiling', 'engines',
'set', 'frozenset', 'sorted', '_function_named')
+
+
+testing.requires = requires
+
+sys.modules['testlib.sa'] = sa = testing.CompositeModule(
+ 'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
+ 'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
+sys.modules['testlib.sa.orm'] = sa.orm
diff --git a/test/testlib/compat.py b/test/testlib/compat.py
index ba12b78ac..fcb7fa1e9 100644
--- a/test/testlib/compat.py
+++ b/test/testlib/compat.py
@@ -1,6 +1,6 @@
-import itertools, new, sys, warnings
+import new
-__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque'
+__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque', 'reversed'
try:
set = set
@@ -69,6 +69,16 @@ except NameError:
return l
try:
+ reversed = reversed
+except NameError:
+ def reversed(seq):
+ i = len(seq) - 1
+ while i >= 0:
+ yield seq[i]
+ i -= 1
+ raise StopIteration()
+
+try:
from collections import deque
except ImportError:
class deque(list):
@@ -77,9 +87,7 @@ except ImportError:
def popleft(self):
return self.pop(0)
def extendleft(self, iterable):
- items = list(iterable)
- items.reverse()
- for x in items:
+ for x in reversed(list(iterable)):
self.insert(0, x)
def _function_named(fn, newname):
diff --git a/test/testlib/engines.py b/test/testlib/engines.py
index f5694df57..5ad35a066 100644
--- a/test/testlib/engines.py
+++ b/test/testlib/engines.py
@@ -1,6 +1,6 @@
import sys, types, weakref
from testlib import config
-from testlib.compat import *
+from testlib.compat import set, _function_named, deque
class ConnectionKiller(object):
diff --git a/test/testlib/filters.py b/test/testlib/filters.py
index eb7eff279..2d559f53b 100644
--- a/test/testlib/filters.py
+++ b/test/testlib/filters.py
@@ -14,8 +14,8 @@ Includes::
"""
import sys
-from StringIO import StringIO
-from tokenize import *
+from tokenize import generate_tokens, INDENT, DEDENT, NAME, OP, NL, NEWLINE, \
+ NUMBER, STRING, COMMENT
__all__ = ['py23_decorators', 'py23']
diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py
index e8d71179a..f56b865c6 100644
--- a/test/testlib/fixtures.py
+++ b/test/testlib/fixtures.py
@@ -1,14 +1,16 @@
-# can't be imported until the path is setup; be sure to configure
-# first if covering.
-from sqlalchemy import *
-from sqlalchemy import util
-from testlib import *
-
-__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest', 'Dingaling', 'item_keywords',
- 'dingalings', 'User', 'items', 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
+from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
+from testlib.sa.orm import attributes
+from testlib import ORMTest
+from testlib.compat import set
+
+
+__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest',
+ 'Dingaling', 'item_keywords', 'dingalings', 'User', 'items',
+ 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
'order_items', 'Item', 'Order', 'fixtures']
-
-_recursion_stack = util.Set()
+
+
+_recursion_stack = set()
class Base(object):
def __init__(self, **kwargs):
for k in kwargs:
@@ -36,10 +38,15 @@ class Base(object):
_recursion_stack.add(self)
try:
# pick the entity thats not SA persisted as the source
+ try:
+ state = attributes.instance_state(self)
+ key = state.key
+ except (KeyError, AttributeError):
+ key = None
if other is None:
a = self
b = other
- elif hasattr(self, '_instance_key'):
+ elif key is not None:
a = other
b = self
else:
@@ -57,8 +64,9 @@ class Base(object):
battr = getattr(b, attr)
except AttributeError:
#print "b class does not have attribute named '%s'" % attr
+ #raise
return False
-
+
if list(value) == list(battr):
continue
else:
@@ -84,43 +92,60 @@ metadata = MetaData()
users = Table('users', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False))
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
+ )
orders = Table('orders', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
Column('address_id', None, ForeignKey('addresses.id')),
Column('description', String(30)),
- Column('isopen', Integer)
+ Column('isopen', Integer),
+ test_needs_acid=True,
+ test_needs_fk=True
)
addresses = Table('addresses', metadata,
Column('id', Integer, primary_key=True),
Column('user_id', None, ForeignKey('users.id')),
- Column('email_address', String(50), nullable=False))
+ Column('email_address', String(50), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True)
dingalings = Table("dingalings", metadata,
Column('id', Integer, primary_key=True),
Column('address_id', None, ForeignKey('addresses.id')),
- Column('data', String(30))
+ Column('data', String(30)),
+ test_needs_acid=True,
+ test_needs_fk=True
)
items = Table('items', metadata,
Column('id', Integer, primary_key=True),
- Column('description', String(30), nullable=False)
+ Column('description', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
order_items = Table('order_items', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('order_id', None, ForeignKey('orders.id')))
+ Column('order_id', None, ForeignKey('orders.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
item_keywords = Table('item_keywords', metadata,
Column('item_id', None, ForeignKey('items.id')),
- Column('keyword_id', None, ForeignKey('keywords.id')))
+ Column('keyword_id', None, ForeignKey('keywords.id')),
+ test_needs_acid=True,
+ test_needs_fk=True)
keywords = Table('keywords', metadata,
Column('id', Integer, primary_key=True),
- Column('name', String(30), nullable=False)
+ Column('name', String(30), nullable=False),
+ test_needs_acid=True,
+ test_needs_fk=True
)
def install_fixture_data():
@@ -203,14 +228,15 @@ def install_fixture_data():
class FixtureTest(ORMTest):
refresh_data = False
-
+ only_tables = False
+
def setUpAll(self):
super(FixtureTest, self).setUpAll()
- if self.keep_data:
+ if not self.only_tables and self.keep_data:
install_fixture_data()
def setUp(self):
- if self.refresh_data:
+ if not self.only_tables and self.refresh_data:
install_fixture_data()
def define_tables(self, meta):
diff --git a/test/testlib/profiling.py b/test/testlib/profiling.py
index b452d1fb8..e423b9904 100644
--- a/test/testlib/profiling.py
+++ b/test/testlib/profiling.py
@@ -1,8 +1,7 @@
"""Profiling support for unit and performance tests."""
import os, sys
-from testlib.config import parser, post_configure
-from testlib.compat import *
+from testlib.compat import set, _function_named
import testlib.config
__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
@@ -26,8 +25,6 @@ def profiled(target=None, **target_opts):
configuration and command-line options.
"""
- import time, hotshot, hotshot.stats
-
# manual or automatic namespacing by module would remove conflict issues
if target is None:
target = 'anonymous_target'
diff --git a/test/testlib/requires.py b/test/testlib/requires.py
new file mode 100644
index 000000000..a4604ff7f
--- /dev/null
+++ b/test/testlib/requires.py
@@ -0,0 +1,32 @@
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+from testlib import testing
+
+def savepoints(fn):
+ """Target database must support savepoints."""
+ return (testing.unsupported(
+ 'access',
+ 'mssql',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
+
+def two_phase_transactions(fn):
+ """Target database must support two-phase transactions."""
+ return (testing.unsupported(
+ 'access',
+ 'firebird',
+ 'maxdb',
+ 'mssql',
+ 'oracle',
+ 'sqlite',
+ 'sybase',
+ )
+ (testing.exclude('mysql', '<', (5, 0, 3))
+ (fn)))
diff --git a/test/testlib/schema.py b/test/testlib/schema.py
index 37f3591ad..9cedc02f0 100644
--- a/test/testlib/schema.py
+++ b/test/testlib/schema.py
@@ -1,5 +1,5 @@
from testlib import testing
-import itertools
+
schema = None
__all__ = 'Table', 'Column',
diff --git a/test/testlib/tables.py b/test/testlib/tables.py
index 33b1b20db..3399acaae 100644
--- a/test/testlib/tables.py
+++ b/test/testlib/tables.py
@@ -1,8 +1,9 @@
# can't be imported until the path is setup; be sure to configure
# first if covering.
-from sqlalchemy import *
+
from testlib import testing
-from testlib.schema import Table, Column
+from testlib.sa import MetaData, Table, Column, Integer, String, Sequence, \
+ ForeignKey, VARCHAR, INT
# these are older test fixtures, used primarily by test/orm/mapper.py and
diff --git a/test/testlib/testing.py b/test/testlib/testing.py
index cf0936e92..1e2ca62e9 100644
--- a/test/testlib/testing.py
+++ b/test/testlib/testing.py
@@ -2,15 +2,27 @@
# monkeypatches unittest.TestLoader.suiteClass at import time
-import itertools, os, operator, re, sys, unittest, warnings
+import itertools
+import operator
+import re
+import sys
+import types
+import unittest
+import warnings
from cStringIO import StringIO
+
import testlib.config as config
-from testlib.compat import *
+from testlib.compat import set, _function_named, reversed
-sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
-sa_exceptions = None
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
-__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')
_ops = { '<': operator.lt,
'>': operator.gt,
@@ -25,6 +37,9 @@ _ops = { '<': operator.lt,
# sugar ('testing.db'); set here by config() at runtime
db = None
+# more sugar, installed by __init__
+requires = None
+
def fails_if(callable_):
"""Mark a test as expected to fail if callable_ returns True.
@@ -224,17 +239,17 @@ def emits_warning(*messages):
# - update: jython looks ok, it uses cpython's module
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SAWarning)]
+ category=sa_exc.SAWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SAWarning)
+ category=sa_exc.SAWarning)
for message in messages ]
for f in filters:
warnings.filterwarnings(**f)
@@ -259,17 +274,17 @@ def uses_deprecated(*messages):
def decorate(fn):
def safe(*args, **kw):
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
if not messages:
filters = [dict(action='ignore',
- category=sa_exceptions.SADeprecationWarning)]
+ category=sa_exc.SADeprecationWarning)]
else:
filters = [dict(action='ignore',
message=message,
- category=sa_exceptions.SADeprecationWarning)
+ category=sa_exc.SADeprecationWarning)
for message in
[ (m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
@@ -287,13 +302,13 @@ def uses_deprecated(*messages):
def resetwarnings():
"""Reset warning behavior to testing defaults."""
- global sa_exceptions
- if sa_exceptions is None:
- import sqlalchemy.exceptions as sa_exceptions
+ global sa_exc
+ if sa_exc is None:
+ import sqlalchemy.exc as sa_exc
warnings.resetwarnings()
- warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning)
- warnings.filterwarnings('error', category=sa_exceptions.SAWarning)
+ warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings('error', category=sa_exc.SAWarning)
if sys.version_info < (2, 4):
warnings.filterwarnings('ignore', category=FutureWarning)
@@ -338,6 +353,23 @@ def rowset(results):
return set([tuple(row) for row in results])
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
class TestData(object):
"""Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
@@ -360,10 +392,6 @@ class ExecutionContextWrapper(object):
can be tracked."""
def __init__(self, ctx):
- global sql
- if sql is None:
- from sqlalchemy import sql
-
self.__dict__['ctx'] = ctx
def __getattr__(self, key):
return getattr(self.ctx, key)
@@ -414,7 +442,7 @@ class ExecutionContextWrapper(object):
query = self.convert_statement(query)
equivalent = ( (statement == query)
- or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
+ or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
) \
and \
( (params is None) or (params == parameters)
@@ -422,7 +450,7 @@ class ExecutionContextWrapper(object):
for (k, v) in p.items()])
for p in parameters]
)
- testdata.unittest.assert_(equivalent,
+ testdata.unittest.assert_(equivalent,
"Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
testdata.sql_count += 1
self.ctx.post_execution()
@@ -445,6 +473,44 @@ class ExecutionContextWrapper(object):
query = re.sub(r':([\w_]+)', repl, query)
return query
+
+def _import_by_name(name):
+ submodule = name.split('.')[-1]
+ return __import__(name, globals(), locals(), [submodule])
+
+class CompositeModule(types.ModuleType):
+ """Merged attribute access for multiple modules."""
+
+ # break the habit
+ __all__ = ()
+
+ def __init__(self, name, *modules, **overrides):
+ """Construct a new lazy composite of modules.
+
+ Modules may be string names or module-like instances. Individual
+ attribute overrides may be specified as keyword arguments for
+ convenience.
+
+ The constructed module will resolve attribute access in reverse order:
+ overrides, then each member of reversed(modules). Modules specified
+ by name will be loaded lazily when encountered in attribute
+ resolution.
+
+ """
+ types.ModuleType.__init__(self, name)
+ self.__modules = list(reversed(modules))
+ for key, value in overrides.iteritems():
+ setattr(self, key, value)
+
+ def __getattr__(self, key):
+ for idx, mod in enumerate(self.__modules):
+ if isinstance(mod, basestring):
+ self.__modules[idx] = mod = _import_by_name(mod)
+ if hasattr(mod, key):
+ return getattr(mod, key)
+ raise AttributeError(key)
+
+
class TestBase(unittest.TestCase):
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
@@ -469,14 +535,14 @@ class TestBase(unittest.TestCase):
def shortDescription(self):
"""overridden to not return docstrings"""
return None
-
+
def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise expected exception"
except except_cls, e:
assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e)
-
+
if not hasattr(unittest.TestCase, 'assertTrue'):
assertTrue = unittest.TestCase.failUnless
if not hasattr(unittest.TestCase, 'assertFalse'):
@@ -522,7 +588,7 @@ class ComparesTables(object):
set(type(c.type).__mro__).difference(base_mro)
)
) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-
+
if isinstance(c.type, sqltypes.String):
self.assertEquals(c.type.length, reflected_c.type.length)
@@ -535,18 +601,18 @@ class ComparesTables(object):
elif not c.primary_key or not against('postgres'):
print repr(c)
assert reflected_c.default is None, reflected_c.default
-
+
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name]
-
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)
self.assert_list(result, class_, objects)
-
+
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
@@ -675,10 +741,10 @@ class ORMTest(TestBase, AssertsExecutionResults):
def define_tables(self, _otest_metadata):
raise NotImplementedError()
-
+
def setup_mappers(self):
pass
-
+
def insert_data(self):
pass