summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xtests/__init__.py2
-rw-r--r--tests/test_async.py252
2 files changed, 254 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index a5d7118..47aedc8 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -27,6 +27,7 @@ import test_transaction
import types_basic
import types_extras
import test_lobject
+import test_async
def test_suite():
suite = unittest.TestSuite()
@@ -40,6 +41,7 @@ def test_suite():
suite.addTest(types_basic.test_suite())
suite.addTest(types_extras.test_suite())
suite.addTest(test_lobject.test_suite())
+ suite.addTest(test_async.test_suite())
return suite
if __name__ == '__main__':
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000..c2bbf4c
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python
+import unittest
+
+import psycopg2
+from psycopg2 import extensions
+
+import select
+import StringIO
+
+import sys
+if sys.version_info < (3,):
+ import tests
+else:
+ import py3tests as tests
+
+
+class AsyncTests(unittest.TestCase):
+
+ def setUp(self):
+ self.sync_conn = psycopg2.connect(tests.dsn)
+ self.conn = psycopg2.connect(tests.dsn, async=True)
+
+ state = psycopg2.extensions.POLL_WRITE
+ while state != psycopg2.extensions.POLL_OK:
+ if state == psycopg2.extensions.POLL_WRITE:
+ select.select([], [self.conn.fileno()], [])
+ elif state == psycopg2.extensions.POLL_READ:
+ select.select([self.conn.fileno()], [], [])
+ state = self.conn.poll()
+
+ curs = self.conn.cursor()
+ curs.execute('''
+ CREATE TEMPORARY TABLE table1 (
+ id int PRIMARY KEY
+ )''')
+ self.conn.commit()
+
+ def tearDown(self):
+ self.sync_conn.close()
+ self.conn.close()
+
+ def wait_for_query(self, cur):
+ state = cur.poll()
+ while state != psycopg2.extensions.POLL_OK:
+ if state == psycopg2.extensions.POLL_READ:
+ select.select([cur.fileno()], [], [])
+ elif state == psycopg2.extensions.POLL_WRITE:
+ select.select([], [cur.fileno()], [])
+ state = cur.poll()
+
+ def test_wrong_execution_type(self):
+ cur = self.conn.cursor()
+ sync_cur = self.sync_conn.cursor()
+
+ self.assertRaises(psycopg2.ProgrammingError, cur.execute,
+ "select 'a'", async=False)
+ self.assertRaises(psycopg2.ProgrammingError, sync_cur.execute,
+ "select 'a'", async=True)
+
+ # but this should work anyway
+ sync_cur.execute("select 'a'", async=False)
+ cur.execute("select 'a'", async=True)
+
+ def test_async_select(self):
+ cur = self.conn.cursor()
+ self.assertFalse(self.conn.executing())
+ cur.execute("select 'a'")
+ self.assertTrue(self.conn.executing())
+
+ self.wait_for_query(cur)
+
+ self.assertFalse(self.conn.executing())
+ self.assertEquals(cur.fetchone()[0], "a")
+
+ def test_async_callproc(self):
+ cur = self.conn.cursor()
+ try:
+ cur.callproc("pg_sleep", (0.1, ), True)
+ except psycopg2.ProgrammingError:
+ # PG <8.1 did not have pg_sleep
+ return
+ self.assertTrue(self.conn.executing())
+
+ self.wait_for_query(cur)
+ self.assertFalse(self.conn.executing())
+ self.assertEquals(cur.fetchall()[0][0], '')
+
+ def test_async_after_async(self):
+ cur = self.conn.cursor()
+ cur2 = self.conn.cursor()
+
+ cur.execute("insert into table1 values (1)")
+
+ # an async execute after an async one blocks and waits for completion
+ cur.execute("select * from table1")
+ self.wait_for_query(cur)
+
+ self.assertEquals(cur.fetchall()[0][0], 1)
+
+ cur.execute("delete from table1")
+ self.wait_for_query(cur)
+
+ cur.execute("select * from table1")
+ self.wait_for_query(cur)
+
+ self.assertEquals(cur.fetchone(), None)
+
+ def test_fetch_after_async(self):
+ cur = self.conn.cursor()
+ cur.execute("select 'a'")
+
+ # a fetch after an asynchronous query blocks and waits for completion
+ self.assertEquals(cur.fetchall()[0][0], "a")
+
+ def test_rollback_while_async(self):
+ cur = self.conn.cursor()
+
+ cur.execute("select 'a'")
+
+ # a rollback blocks and should leave the connection in a workable state
+ self.conn.rollback()
+ self.assertFalse(self.conn.executing())
+
+ # try a sync cursor first
+ sync_cur = self.sync_conn.cursor()
+ sync_cur.execute("select 'b'")
+ self.assertEquals(sync_cur.fetchone()[0], "b")
+
+ # now try the async cursor
+ cur.execute("select 'c'")
+ self.wait_for_query(cur)
+ self.assertEquals(cur.fetchmany()[0][0], "c")
+
+ def test_commit_while_async(self):
+ cur = self.conn.cursor()
+
+ cur.execute("insert into table1 values (1)")
+
+ # a commit blocks
+ self.conn.commit()
+ self.assertFalse(self.conn.executing())
+
+ cur.execute("select * from table1")
+ self.wait_for_query(cur)
+ self.assertEquals(cur.fetchall()[0][0], 1)
+
+ cur.execute("delete from table1")
+ self.conn.commit()
+
+ cur.execute("select * from table1")
+ self.wait_for_query(cur)
+ self.assertEquals(cur.fetchone(), None)
+
+ def test_set_parameters_while_async(self):
+ prev_encoding = self.conn.encoding
+ cur = self.conn.cursor()
+
+ cur.execute("select 'c'")
+ self.assertTrue(self.conn.executing())
+
+ # getting transaction status works
+ self.assertEquals(self.conn.get_transaction_status(),
+ extensions.TRANSACTION_STATUS_ACTIVE)
+ self.assertTrue(self.conn.executing())
+
+ # this issues a ROLLBACK internally
+ self.conn.set_client_encoding("LATIN1")
+
+ self.assertFalse(self.conn.executing())
+ self.assertEquals(self.conn.encoding, "LATIN1")
+
+ self.conn.set_client_encoding(prev_encoding)
+
+ def test_reset_while_async(self):
+ prev_encoding = self.conn.encoding
+ # pick something different than the current encoding
+ new_encoding = (prev_encoding == "LATIN1") and "UTF8" or "LATIN1"
+
+ self.conn.set_client_encoding(new_encoding)
+
+ cur = self.conn.cursor()
+ cur.execute("select 'c'")
+ self.assertTrue(self.conn.executing())
+
+ self.conn.reset()
+ self.assertFalse(self.conn.executing())
+ self.assertEquals(self.conn.encoding, prev_encoding)
+
+ def test_async_iter(self):
+ cur = self.conn.cursor()
+
+ cur.execute("insert into table1 values (1), (2), (3)")
+ self.wait_for_query(cur)
+ cur.execute("select id from table1 order by id")
+
+ # iteration just blocks
+ self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
+ self.assertFalse(self.conn.executing())
+
+ def test_copy_while_async(self):
+ cur = self.conn.cursor()
+ cur.execute("select 'a'")
+
+ # copy just blocks
+ cur.copy_from(StringIO.StringIO("1\n3\n5\n\\.\n"), "table1")
+
+ cur.execute("select * from table1 order by id")
+ self.assertEquals(cur.fetchall(), [(1, ), (3, ), (5, )])
+
+ def test_async_executemany(self):
+ cur = self.conn.cursor()
+ self.assertRaises(
+ psycopg2.ProgrammingError,
+ cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
+
+ def test_async_scroll(self):
+ cur = self.conn.cursor()
+ cur.execute("insert into table1 values (1), (2), (3)")
+ self.wait_for_query(cur)
+ cur.execute("select id from table1 order by id")
+
+ # scroll blocks, but should work
+ cur.scroll(1)
+ self.assertFalse(self.conn.executing())
+ self.assertEquals(cur.fetchall(), [(2, ), (3, )])
+
+ cur = self.conn.cursor()
+ cur.execute("select id from table1 order by id")
+
+ cur2 = self.conn.cursor()
+ self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
+
+ self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
+
+ cur = self.conn.cursor()
+ cur.execute("select id from table1 order by id")
+ cur.scroll(2)
+ cur.scroll(-1)
+ self.assertEquals(cur.fetchall(), [(2, ), (3, )])
+
+ def test_async_dont_read_all(self):
+ cur = self.conn.cursor()
+ cur.execute("select 'a'; select 'b'")
+
+ # fetch the result
+ self.wait_for_query(cur)
+
+ # it should be the result of the second query
+ self.assertEquals(cur.fetchone()[0][0], "b")
+
+def test_suite():
+ return unittest.TestLoader().loadTestsFromName(__name__)