summaryrefslogtreecommitdiff
path: root/tests/testutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/testutils.py')
-rw-r--r--tests/testutils.py95
1 files changed, 90 insertions, 5 deletions
diff --git a/tests/testutils.py b/tests/testutils.py
index 26551d4..1e2e150 100644
--- a/tests/testutils.py
+++ b/tests/testutils.py
@@ -26,6 +26,8 @@
import os
import sys
+from functools import wraps
+from testconfig import dsn
try:
import unittest2
@@ -43,6 +45,7 @@ else:
def skipIf(cond, msg):
def skipIf_(f):
+ @wraps(f)
def skipIf__(self):
if cond:
warnings.warn(msg)
@@ -72,15 +75,62 @@ or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue:
unittest.TestCase.failUnlessEqual = unittest.TestCase.assertEqual
-def decorate_all_tests(cls, decorator):
- """Apply *decorator* to all the tests defined in the TestCase *cls*."""
+class ConnectingTestCase(unittest.TestCase):
+ """A test case providing connections for tests.
+
+ A connection for the test is always available as `self.conn`. Others can be
+ created with `self.connect()`. All are closed on tearDown.
+
+ Subclasses needing to customize setUp and tearDown should remember to call
+ the base class implementations.
+ """
+ def setUp(self):
+ self._conns = []
+
+ def tearDown(self):
+ # close the connections used in the test
+ for conn in self._conns:
+ if not conn.closed:
+ conn.close()
+
+ def connect(self, **kwargs):
+ try:
+ self._conns
+ except AttributeError, e:
+ raise AttributeError(
+ "%s (did you remember calling ConnectingTestCase.setUp()?)"
+ % e)
+
+ import psycopg2
+ conn = psycopg2.connect(dsn, **kwargs)
+ self._conns.append(conn)
+ return conn
+
+ def _get_conn(self):
+ if not hasattr(self, '_the_conn'):
+ self._the_conn = self.connect()
+
+ return self._the_conn
+
+ def _set_conn(self, conn):
+ self._the_conn = conn
+
+ conn = property(_get_conn, _set_conn)
+
+
+def decorate_all_tests(cls, *decorators):
+ """
+ Apply all the *decorators* to all the tests defined in the TestCase *cls*.
+ """
for n in dir(cls):
if n.startswith('test'):
- setattr(cls, n, decorator(getattr(cls, n)))
+ for d in decorators:
+ setattr(cls, n, d(getattr(cls, n)))
def skip_if_no_uuid(f):
"""Decorator to skip a test if uuid is not supported by Py/PG."""
+ @wraps(f)
def skip_if_no_uuid_(self):
try:
import uuid
@@ -104,6 +154,7 @@ def skip_if_no_uuid(f):
def skip_if_tpc_disabled(f):
"""Skip a test if the server has tpc support disabled."""
+ @wraps(f)
def skip_if_tpc_disabled_(self):
from psycopg2 import ProgrammingError
cnn = self.connect()
@@ -123,11 +174,11 @@ def skip_if_tpc_disabled(f):
"set max_prepared_transactions to > 0 to run the test")
return f(self)
- skip_if_tpc_disabled_.__name__ = f.__name__
return skip_if_tpc_disabled_
def skip_if_no_namedtuple(f):
+ @wraps(f)
def skip_if_no_namedtuple_(self):
try:
from collections import namedtuple
@@ -136,12 +187,12 @@ def skip_if_no_namedtuple(f):
else:
return f(self)
- skip_if_no_namedtuple_.__name__ = f.__name__
return skip_if_no_namedtuple_
def skip_if_no_iobase(f):
"""Skip a test if io.TextIOBase is not available."""
+ @wraps(f)
def skip_if_no_iobase_(self):
try:
from io import TextIOBase
@@ -157,6 +208,7 @@ def skip_before_postgres(*ver):
"""Skip a test on PostgreSQL before a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_before_postgres_(f):
+ @wraps(f)
def skip_before_postgres__(self):
if self.conn.server_version < int("%d%02d%02d" % ver):
return self.skipTest("skipped because PostgreSQL %s"
@@ -171,6 +223,7 @@ def skip_after_postgres(*ver):
"""Skip a test on PostgreSQL after (including) a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_after_postgres_(f):
+ @wraps(f)
def skip_after_postgres__(self):
if self.conn.server_version >= int("%d%02d%02d" % ver):
return self.skipTest("skipped because PostgreSQL %s"
@@ -184,6 +237,7 @@ def skip_after_postgres(*ver):
def skip_before_python(*ver):
"""Skip a test on Python before a certain version."""
def skip_before_python_(f):
+ @wraps(f)
def skip_before_python__(self):
if sys.version_info[:len(ver)] < ver:
return self.skipTest("skipped because Python %s"
@@ -197,6 +251,7 @@ def skip_before_python(*ver):
def skip_from_python(*ver):
"""Skip a test on Python after (including) a certain version."""
def skip_from_python_(f):
+ @wraps(f)
def skip_from_python__(self):
if sys.version_info[:len(ver)] >= ver:
return self.skipTest("skipped because Python %s"
@@ -207,6 +262,36 @@ def skip_from_python(*ver):
return skip_from_python__
return skip_from_python_
+def skip_if_no_superuser(f):
+ """Skip a test if the database user running the test is not a superuser"""
+ @wraps(f)
+ def skip_if_no_superuser_(self):
+ from psycopg2 import ProgrammingError
+ try:
+ return f(self)
+ except ProgrammingError, e:
+ import psycopg2.errorcodes
+ if e.pgcode == psycopg2.errorcodes.INSUFFICIENT_PRIVILEGE:
+ self.skipTest("skipped because not superuser")
+ else:
+ raise
+
+ return skip_if_no_superuser_
+
+def skip_if_green(reason):
+ def skip_if_green_(f):
+ @wraps(f)
+ def skip_if_green__(self):
+ from testconfig import green
+ if green:
+ return self.skipTest(reason)
+ else:
+ return f(self)
+
+ return skip_if_green__
+ return skip_if_green_
+
+skip_copy_if_green = skip_if_green("copy in async mode currently not supported")
def script_to_py3(script):
"""Convert a script to Python3 syntax if required."""