summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2016-07-01 18:03:12 +0100
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2016-07-01 18:03:12 +0100
commit1442655d3ce8b64abe57200f18bd8878a8b2c4a6 (patch)
tree2571ed219d69dc129161510f37858c1b681e71fc
parentc29b5cd46a24fd81cff8b3affd9c78d18d53aa69 (diff)
parent4a450b63c418bf7e6e62f7b444fd2edd9db246da (diff)
downloadpsycopg2-1442655d3ce8b64abe57200f18bd8878a8b2c4a6.tar.gz
Merge branch 'qstring-writable-encoding'
-rw-r--r--NEWS1
-rw-r--r--psycopg/adapter_qstring.c74
-rw-r--r--psycopg/adapter_qstring.h3
-rwxr-xr-xtests/test_quote.py55
-rwxr-xr-xtests/test_types_basic.py15
-rw-r--r--tests/testconfig.py2
6 files changed, 114 insertions, 36 deletions
diff --git a/NEWS b/NEWS
index 9056b0f..4571f67 100644
--- a/NEWS
+++ b/NEWS
@@ -26,6 +26,7 @@ What's new in psycopg 2.6.2
- Report the server response status on errors (such as :ticket:`#281`).
- Raise `!NotSupportedError` on unhandled server response status
(:ticket:`#352`).
+- Allow overriding string adapter encoding with no connection (:ticket:`#331`).
- The `~psycopg2.extras.wait_select` callback allows interrupting a
long-running query in an interactive shell using :kbd:`Ctrl-C`
(:ticket:`#333`).
diff --git a/psycopg/adapter_qstring.c b/psycopg/adapter_qstring.c
index 2e3ab0a..110093e 100644
--- a/psycopg/adapter_qstring.c
+++ b/psycopg/adapter_qstring.c
@@ -36,28 +36,43 @@ static const char *default_encoding = "latin1";
/* qstring_quote - do the quote process on plain and unicode strings */
+const char *
+_qstring_get_encoding(qstringObject *self)
+{
+ /* if the wrapped object is an unicode object we can encode it to match
+ conn->encoding but if the encoding is not specified we don't know what
+ to do and we raise an exception */
+ if (self->conn) {
+ return self->conn->codec;
+ }
+ else {
+ return self->encoding ? self->encoding : default_encoding;
+ }
+}
+
static PyObject *
qstring_quote(qstringObject *self)
{
PyObject *str = NULL;
char *s, *buffer = NULL;
Py_ssize_t len, qlen;
- const char *encoding = default_encoding;
+ const char *encoding;
PyObject *rv = NULL;
- /* if the wrapped object is an unicode object we can encode it to match
- conn->encoding but if the encoding is not specified we don't know what
- to do and we raise an exception */
- if (self->conn) {
- encoding = self->conn->codec;
- }
-
+ encoding = _qstring_get_encoding(self);
Dprintf("qstring_quote: encoding to %s", encoding);
- if (PyUnicode_Check(self->wrapped) && encoding) {
- str = PyUnicode_AsEncodedString(self->wrapped, encoding, NULL);
- Dprintf("qstring_quote: got encoded object at %p", str);
- if (str == NULL) goto exit;
+ if (PyUnicode_Check(self->wrapped)) {
+ if (encoding) {
+ str = PyUnicode_AsEncodedString(self->wrapped, encoding, NULL);
+ Dprintf("qstring_quote: got encoded object at %p", str);
+ if (str == NULL) goto exit;
+ }
+ else {
+ PyErr_SetString(PyExc_TypeError,
+ "missing encoding to encode unicode object");
+ goto exit;
+ }
}
#if PY_MAJOR_VERSION < 3
@@ -72,8 +87,7 @@ qstring_quote(qstringObject *self)
/* if the wrapped object is not a string, this is an error */
else {
- PyErr_SetString(PyExc_TypeError,
- "can't quote non-string object (or missing encoding)");
+ PyErr_SetString(PyExc_TypeError, "can't quote non-string object");
goto exit;
}
@@ -150,13 +164,32 @@ qstring_conform(qstringObject *self, PyObject *args)
static PyObject *
qstring_get_encoding(qstringObject *self)
{
- const char *encoding = default_encoding;
+ const char *encoding;
+ encoding = _qstring_get_encoding(self);
+ return Text_FromUTF8(encoding);
+}
- if (self->conn) {
- encoding = self->conn->codec;
- }
+static int
+qstring_set_encoding(qstringObject *self, PyObject *pyenc)
+{
+ int rv = -1;
+ const char *tmp;
+ char *cenc;
- return Text_FromUTF8(encoding);
+ /* get a C copy of the encoding (which may come from unicode) */
+ Py_INCREF(pyenc);
+ if (!(pyenc = psycopg_ensure_bytes(pyenc))) { goto exit; }
+ if (!(tmp = Bytes_AsString(pyenc))) { goto exit; }
+ if (0 > psycopg_strdup(&cenc, tmp, 0)) { goto exit; }
+
+ Dprintf("qstring_set_encoding: encoding set to %s", cenc);
+ PyMem_Free((void *)self->encoding);
+ self->encoding = cenc;
+ rv = 0;
+
+exit:
+ Py_XDECREF(pyenc);
+ return rv;
}
/** the QuotedString object **/
@@ -183,7 +216,7 @@ static PyMethodDef qstringObject_methods[] = {
static PyGetSetDef qstringObject_getsets[] = {
{ "encoding",
(getter)qstring_get_encoding,
- (setter)NULL,
+ (setter)qstring_set_encoding,
"current encoding of the adapter" },
{NULL}
};
@@ -216,6 +249,7 @@ qstring_dealloc(PyObject* obj)
Py_CLEAR(self->wrapped);
Py_CLEAR(self->buffer);
Py_CLEAR(self->conn);
+ PyMem_Free((void *)self->encoding);
Dprintf("qstring_dealloc: deleted qstring object at %p, refcnt = "
FORMAT_CODE_PY_SSIZE_T,
diff --git a/psycopg/adapter_qstring.h b/psycopg/adapter_qstring.h
index b7b086f..8abdc5f 100644
--- a/psycopg/adapter_qstring.h
+++ b/psycopg/adapter_qstring.h
@@ -39,6 +39,9 @@ typedef struct {
PyObject *buffer;
connectionObject *conn;
+
+ const char *encoding;
+
} qstringObject;
#ifdef __cplusplus
diff --git a/tests/test_quote.py b/tests/test_quote.py
index 6e94562..0a204c8 100755
--- a/tests/test_quote.py
+++ b/tests/test_quote.py
@@ -29,6 +29,7 @@ import psycopg2
import psycopg2.extensions
from psycopg2.extensions import b
+
class QuotingTestCase(ConnectingTestCase):
r"""Checks the correct quoting of strings and binary objects.
@@ -51,7 +52,7 @@ class QuotingTestCase(ConnectingTestCase):
data = """some data with \t chars
to escape into, 'quotes' and \\ a backslash too.
"""
- data += "".join(map(chr, range(1,127)))
+ data += "".join(map(chr, range(1, 127)))
curs = self.conn.cursor()
curs.execute("SELECT %s;", (data,))
@@ -90,13 +91,13 @@ class QuotingTestCase(ConnectingTestCase):
if server_encoding != "UTF8":
return self.skipTest(
"Unicode test skipped since server encoding is %s"
- % server_encoding)
+ % server_encoding)
data = u"""some data with \t chars
to escape into, 'quotes', \u20ac euro sign and \\ a backslash too.
"""
- data += u"".join(map(unichr, [ u for u in range(1,65536)
- if not 0xD800 <= u <= 0xDFFF ])) # surrogate area
+ data += u"".join(map(unichr, [u for u in range(1, 65536)
+ if not 0xD800 <= u <= 0xDFFF])) # surrogate area
self.conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
@@ -156,7 +157,7 @@ class QuotingTestCase(ConnectingTestCase):
class TestQuotedString(ConnectingTestCase):
- def test_encoding(self):
+ def test_encoding_from_conn(self):
q = psycopg2.extensions.QuotedString('hi')
self.assertEqual(q.encoding, 'latin1')
@@ -183,9 +184,51 @@ class TestQuotedIdentifier(ConnectingTestCase):
self.assertEqual(quote_ident(snowman, self.conn), quoted)
+class TestStringAdapter(ConnectingTestCase):
+ def test_encoding_default(self):
+ from psycopg2.extensions import adapt
+ a = adapt("hello")
+ self.assertEqual(a.encoding, 'latin1')
+ self.assertEqual(a.getquoted(), "'hello'")
+
+ # NOTE: we can't really test an encoding different from utf8, because
+ # when encoding without connection the libpq will use parameters from
+ # a previous one, so what would happens depends jn the tests run order.
+ # egrave = u'\xe8'
+ # self.assertEqual(adapt(egrave).getquoted(), "'\xe8'")
+
+ def test_encoding_error(self):
+ from psycopg2.extensions import adapt
+ snowman = u"\u2603"
+ a = adapt(snowman)
+ self.assertRaises(UnicodeEncodeError, a.getquoted)
+
+ def test_set_encoding(self):
+ # Note: this works-ish mostly in case when the standard db connection
+ # we test with is utf8, otherwise the encoding chosen by PQescapeString
+ # may give bad results.
+ from psycopg2.extensions import adapt
+ snowman = u"\u2603"
+ a = adapt(snowman)
+ a.encoding = 'utf8'
+ self.assertEqual(a.encoding, 'utf8')
+ self.assertEqual(a.getquoted(), "'\xe2\x98\x83'")
+
+ def test_connection_wins_anyway(self):
+ from psycopg2.extensions import adapt
+ snowman = u"\u2603"
+ a = adapt(snowman)
+ a.encoding = 'latin9'
+
+ self.conn.set_client_encoding('utf8')
+ a.prepare(self.conn)
+
+ self.assertEqual(a.encoding, 'utf_8')
+ self.assertEqual(a.getquoted(), "'\xe2\x98\x83'")
+
+
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()
-
diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py
index 4923d82..248712b 100755
--- a/tests/test_types_basic.py
+++ b/tests/test_types_basic.py
@@ -95,11 +95,11 @@ class TypesBasicTests(ConnectingTestCase):
except ValueError:
return self.skipTest("inf not available on this platform")
s = self.execute("SELECT %s AS foo", (float("inf"),))
- self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s))
+ self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s))
self.failUnless(type(s) == float, "wrong float conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (float("-inf"),))
- self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
+ self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
def testBinary(self):
if sys.version_info[0] < 3:
@@ -364,8 +364,8 @@ class AdaptSubclassTest(unittest.TestCase):
try:
self.assertEqual(b('b'), adapt(C()).getquoted())
finally:
- del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
- del psycopg2.extensions.adapters[B, psycopg2.extensions.ISQLQuote]
+ del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
+ del psycopg2.extensions.adapters[B, psycopg2.extensions.ISQLQuote]
@testutils.skip_from_python(3)
def test_no_mro_no_joy(self):
@@ -378,8 +378,7 @@ class AdaptSubclassTest(unittest.TestCase):
try:
self.assertRaises(psycopg2.ProgrammingError, adapt, B())
finally:
- del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
-
+ del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
@testutils.skip_before_python(3)
def test_adapt_subtype_3(self):
@@ -392,7 +391,7 @@ class AdaptSubclassTest(unittest.TestCase):
try:
self.assertEqual(b("a"), adapt(B()).getquoted())
finally:
- del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
+ del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
class ByteaParserTest(unittest.TestCase):
@@ -480,6 +479,7 @@ class ByteaParserTest(unittest.TestCase):
self.assertEqual(rv, tgt)
+
def skip_if_cant_cast(f):
@wraps(f)
def skip_if_cant_cast_(self, *args, **kwargs):
@@ -499,4 +499,3 @@ def test_suite():
if __name__ == "__main__":
unittest.main()
-
diff --git a/tests/testconfig.py b/tests/testconfig.py
index 0f995fb..72c533e 100644
--- a/tests/testconfig.py
+++ b/tests/testconfig.py
@@ -34,5 +34,3 @@ if dbuser is not None:
dsn += ' user=%s' % dbuser
if dbpass is not None:
dsn += ' password=%s' % dbpass
-
-