summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NEWS2
-rw-r--r--psycopg/cursor.h9
-rwxr-xr-xtests/test_cursor.py14
3 files changed, 22 insertions, 3 deletions
diff --git a/NEWS b/NEWS
index 272be2d..d5f1686 100644
--- a/NEWS
+++ b/NEWS
@@ -17,6 +17,8 @@ What's new in psycopg 2.5.3
Chris Withers (:ticket:`#193`).
- Avoid blocking async connections on connect (:ticket:`#194`). Thanks to
Adam Petrovich for the bug report and diagnosis.
+- Don't segfault using poorly defined cursor subclasses which forgot to call
+ the superclass init (:ticket:`#195`).
- Fixed debug build on Windows, thanks to James Emerton.
diff --git a/psycopg/cursor.h b/psycopg/cursor.h
index f96c657..e291d45 100644
--- a/psycopg/cursor.h
+++ b/psycopg/cursor.h
@@ -97,11 +97,14 @@ HIDDEN int psyco_curs_scrollable_set(cursorObject *self, PyObject *pyvalue);
/* exception-raising macros */
#define EXC_IF_CURS_CLOSED(self) \
-do \
- if ((self)->closed || ((self)->conn && (self)->conn->closed)) { \
+do { \
+ if (!(self)->conn) { \
+ PyErr_SetString(InterfaceError, "the cursor has no connection"); \
+ return NULL; } \
+ if ((self)->closed || (self)->conn->closed) { \
PyErr_SetString(InterfaceError, "cursor already closed"); \
return NULL; } \
-while (0)
+} while (0)
#define EXC_IF_NO_TUPLES(self) \
do \
diff --git a/tests/test_cursor.py b/tests/test_cursor.py
index c35d26c..cba5cca 100755
--- a/tests/test_cursor.py
+++ b/tests/test_cursor.py
@@ -413,6 +413,20 @@ class CursorTests(ConnectingTestCase):
cur.scroll(9, mode='absolute')
self.assertEqual(cur.fetchone(), (9,))
+ def test_bad_subclass(self):
+ # check that we get an error message instead of a segfault
+ # for badly written subclasses.
+ # see http://stackoverflow.com/questions/22019341/
+ class StupidCursor(psycopg2.extensions.cursor):
+ def __init__(self, *args, **kwargs):
+ # I am stupid so not calling superclass init
+ pass
+
+ cur = StupidCursor()
+ self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1')
+ self.assertRaises(psycopg2.InterfaceError, cur.executemany,
+ 'select 1', [])
+
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)