diff options
Diffstat (limited to 'tests/testutils.py')
-rw-r--r-- | tests/testutils.py | 95 |
1 files changed, 90 insertions, 5 deletions
diff --git a/tests/testutils.py b/tests/testutils.py index 26551d4..1e2e150 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -26,6 +26,8 @@ import os import sys +from functools import wraps +from testconfig import dsn try: import unittest2 @@ -43,6 +45,7 @@ else: def skipIf(cond, msg): def skipIf_(f): + @wraps(f) def skipIf__(self): if cond: warnings.warn(msg) @@ -72,15 +75,62 @@ or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue: unittest.TestCase.failUnlessEqual = unittest.TestCase.assertEqual -def decorate_all_tests(cls, decorator): - """Apply *decorator* to all the tests defined in the TestCase *cls*.""" +class ConnectingTestCase(unittest.TestCase): + """A test case providing connections for tests. + + A connection for the test is always available as `self.conn`. Others can be + created with `self.connect()`. All are closed on tearDown. + + Subclasses needing to customize setUp and tearDown should remember to call + the base class implementations. + """ + def setUp(self): + self._conns = [] + + def tearDown(self): + # close the connections used in the test + for conn in self._conns: + if not conn.closed: + conn.close() + + def connect(self, **kwargs): + try: + self._conns + except AttributeError, e: + raise AttributeError( + "%s (did you remember calling ConnectingTestCase.setUp()?)" + % e) + + import psycopg2 + conn = psycopg2.connect(dsn, **kwargs) + self._conns.append(conn) + return conn + + def _get_conn(self): + if not hasattr(self, '_the_conn'): + self._the_conn = self.connect() + + return self._the_conn + + def _set_conn(self, conn): + self._the_conn = conn + + conn = property(_get_conn, _set_conn) + + +def decorate_all_tests(cls, *decorators): + """ + Apply all the *decorators* to all the tests defined in the TestCase *cls*. + """ for n in dir(cls): if n.startswith('test'): - setattr(cls, n, decorator(getattr(cls, n))) + for d in decorators: + setattr(cls, n, d(getattr(cls, n))) def skip_if_no_uuid(f): """Decorator to skip a test if uuid is not supported by Py/PG.""" + @wraps(f) def skip_if_no_uuid_(self): try: import uuid @@ -104,6 +154,7 @@ def skip_if_no_uuid(f): def skip_if_tpc_disabled(f): """Skip a test if the server has tpc support disabled.""" + @wraps(f) def skip_if_tpc_disabled_(self): from psycopg2 import ProgrammingError cnn = self.connect() @@ -123,11 +174,11 @@ def skip_if_tpc_disabled(f): "set max_prepared_transactions to > 0 to run the test") return f(self) - skip_if_tpc_disabled_.__name__ = f.__name__ return skip_if_tpc_disabled_ def skip_if_no_namedtuple(f): + @wraps(f) def skip_if_no_namedtuple_(self): try: from collections import namedtuple @@ -136,12 +187,12 @@ def skip_if_no_namedtuple(f): else: return f(self) - skip_if_no_namedtuple_.__name__ = f.__name__ return skip_if_no_namedtuple_ def skip_if_no_iobase(f): """Skip a test if io.TextIOBase is not available.""" + @wraps(f) def skip_if_no_iobase_(self): try: from io import TextIOBase @@ -157,6 +208,7 @@ def skip_before_postgres(*ver): """Skip a test on PostgreSQL before a certain version.""" ver = ver + (0,) * (3 - len(ver)) def skip_before_postgres_(f): + @wraps(f) def skip_before_postgres__(self): if self.conn.server_version < int("%d%02d%02d" % ver): return self.skipTest("skipped because PostgreSQL %s" @@ -171,6 +223,7 @@ def skip_after_postgres(*ver): """Skip a test on PostgreSQL after (including) a certain version.""" ver = ver + (0,) * (3 - len(ver)) def skip_after_postgres_(f): + @wraps(f) def skip_after_postgres__(self): if self.conn.server_version >= int("%d%02d%02d" % ver): return self.skipTest("skipped because PostgreSQL %s" @@ -184,6 +237,7 @@ def skip_after_postgres(*ver): def skip_before_python(*ver): """Skip a test on Python before a certain version.""" def skip_before_python_(f): + @wraps(f) def skip_before_python__(self): if sys.version_info[:len(ver)] < ver: return self.skipTest("skipped because Python %s" @@ -197,6 +251,7 @@ def skip_before_python(*ver): def skip_from_python(*ver): """Skip a test on Python after (including) a certain version.""" def skip_from_python_(f): + @wraps(f) def skip_from_python__(self): if sys.version_info[:len(ver)] >= ver: return self.skipTest("skipped because Python %s" @@ -207,6 +262,36 @@ def skip_from_python(*ver): return skip_from_python__ return skip_from_python_ +def skip_if_no_superuser(f): + """Skip a test if the database user running the test is not a superuser""" + @wraps(f) + def skip_if_no_superuser_(self): + from psycopg2 import ProgrammingError + try: + return f(self) + except ProgrammingError, e: + import psycopg2.errorcodes + if e.pgcode == psycopg2.errorcodes.INSUFFICIENT_PRIVILEGE: + self.skipTest("skipped because not superuser") + else: + raise + + return skip_if_no_superuser_ + +def skip_if_green(reason): + def skip_if_green_(f): + @wraps(f) + def skip_if_green__(self): + from testconfig import green + if green: + return self.skipTest(reason) + else: + return f(self) + + return skip_if_green__ + return skip_if_green_ + +skip_copy_if_green = skip_if_green("copy in async mode currently not supported") def script_to_py3(script): """Convert a script to Python3 syntax if required.""" |