summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--psycopg/cursor_type.c6
-rwxr-xr-xtests/test_async.py29
2 files changed, 34 insertions, 1 deletions
diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c
index 895b3ce..05d1436 100644
--- a/psycopg/cursor_type.c
+++ b/psycopg/cursor_type.c
@@ -1480,6 +1480,12 @@ psyco_curs_poll(cursorObject *self)
{
EXC_IF_CURS_CLOSED(self);
+ if (self->conn->async_cursor != NULL &&
+ self->conn->async_cursor != (PyObject *) self) {
+ PyErr_SetString(ProgrammingError, "poll with wrong cursor");
+ return NULL;
+ }
+
Dprintf("curs_poll: polling with status %d", self->conn->async_status);
if (self->conn->async_status == ASYNC_WRITE) {
diff --git a/tests/test_async.py b/tests/test_async.py
index da084bd..6cb9149 100755
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -299,7 +299,6 @@ class AsyncTests(unittest.TestCase):
self.assert_(not conn.issync())
conn.close()
-
def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend
curs = self.conn.cursor()
@@ -315,6 +314,34 @@ class AsyncTests(unittest.TestCase):
self.fail("sending a large query didn't trigger block on write.")
+ def test_sync_poll(self):
+ cur = self.sync_conn.cursor()
+ # polling a sync cursor works
+ cur.poll()
+
+ def test_async_poll_wrong_cursor(self):
+ cur1 = self.conn.cursor()
+ cur2 = self.conn.cursor()
+ cur1.execute("select 1")
+
+ # polling a cursor that's not currently executing is an error
+ self.assertRaises(psycopg2.ProgrammingError, cur2.poll)
+
+ self.wait_for_query(cur1)
+ self.assertEquals(cur1.fetchone()[0], 1)
+
+ def test_async_fetch_wrong_cursor(self):
+ cur1 = self.conn.cursor()
+ cur2 = self.conn.cursor()
+ cur1.execute("select 1")
+
+ self.wait_for_query(cur1)
+ self.assertFalse(self.conn.executing())
+ # fetching from a cursor with no results is an error
+ self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone)
+ # fetching from the correct cursor works
+ self.assertEquals(cur1.fetchone()[0], 1)
+
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)