diff options
-rw-r--r-- | NEWS | 1 | ||||
-rw-r--r-- | doc/src/faq.rst | 18 | ||||
-rw-r--r-- | doc/src/index.rst | 1 | ||||
-rw-r--r-- | doc/src/sql.rst | 89 | ||||
-rw-r--r-- | doc/src/usage.rst | 7 | ||||
-rw-r--r-- | lib/sql.py | 424 | ||||
-rw-r--r-- | psycopg/cursor_type.c | 63 | ||||
-rwxr-xr-x | tests/__init__.py | 7 | ||||
-rwxr-xr-x | tests/test_ipaddress.py | 4 | ||||
-rwxr-xr-x | tests/test_sql.py | 378 |
10 files changed, 978 insertions, 14 deletions
@@ -6,6 +6,7 @@ What's new in psycopg 2.7 New features: +- Added `~psycopg2.sql` module to generate SQL dynamically (:ticket:`#308`). - Added :ref:`replication-support` (:ticket:`#322`). Main authors are Oleksandr Shulgin and Craig Ringer, who deserve a huge thank you. - Added `~psycopg2.extensions.parse_dsn()` and diff --git a/doc/src/faq.rst b/doc/src/faq.rst index 89d8a63..fb7b33d 100644 --- a/doc/src/faq.rst +++ b/doc/src/faq.rst @@ -151,6 +151,24 @@ Psycopg converts :sql:`json` values into Python objects but :sql:`jsonb` values See :ref:`adapt-json` for further details. +.. _faq-identifier: +.. cssclass:: faq + +How can I pass field/table names to a query? + The arguments in the `~cursor.execute()` methods can only represent data + to pass to the query: they cannot represent a table or field name:: + + # This doesn't work + cur.execute("insert into %s values (%s)", ["my_table", 42]) + + If you want to build a query dynamically you can use the objects exposed + by the `psycopg2.sql` module:: + + cur.execute( + sql.SQL("insert into %s values (%%s)") % [sql.Identifier("my_table")], + [42]) + + .. _faq-bytea-9.0: .. cssclass:: faq diff --git a/doc/src/index.rst b/doc/src/index.rst index 5cf0f24..96a1423 100644 --- a/doc/src/index.rst +++ b/doc/src/index.rst @@ -44,6 +44,7 @@ Psycopg 2 is both Unicode and Python 3 friendly. advanced extensions extras + sql tz pool errorcodes diff --git a/doc/src/sql.rst b/doc/src/sql.rst new file mode 100644 index 0000000..0aee451 --- /dev/null +++ b/doc/src/sql.rst @@ -0,0 +1,89 @@ +`psycopg2.sql` -- SQL string composition +======================================== + +.. sectionauthor:: Daniele Varrazzo <daniele.varrazzo@gmail.com> + +.. module:: psycopg2.sql + +.. versionadded:: 2.7 + +The module contains objects and functions useful to generate SQL dynamically, +in a convenient and safe way. SQL identifiers (e.g. names of tables and +fields) cannot be passed to the `~cursor.execute()` method like query +arguments:: + + # This will not work + table_name = 'my_table' + cur.execute("insert into %s values (%s, %s)", [table_name, 10, 20]) + +The SQL query should be composed before the arguments are merged, for +instance:: + + # This works, but it is not optimal + table_name = 'my_table' + cur.execute( + "insert into %s values (%%s, %%s)" % table_name, + [10, 20]) + +This sort of works, but it is an accident waiting to happen: the table name +may be an invalid SQL literal and need quoting; even more serious is the +security problem in case the table name comes from an untrusted source. The +name should be escaped using `~psycopg2.extensions.quote_ident()`:: + + # This works, but it is not optimal + table_name = 'my_table' + cur.execute( + "insert into %s values (%%s, %%s)" % ext.quote_ident(table_name), + [10, 20]) + +This is now safe, but it somewhat ad-hoc. In case, for some reason, it is +necessary to include a value in the query string (as opposite as in a value) +the merging rule is still different (`~psycopg2.extensions.adapt()` should be +used...). It is also still relatively dangerous: if `!quote_ident()` is +forgotten somewhere, the program will usually work, but will eventually crash +in the presence of a table or field name with containing characters to escape, +or will present a potentially exploitable weakness. + +The objects exposed by the `!psycopg2.sql` module allow generating SQL +statements on the fly, separating clearly the variable parts of the statement +from the query parameters:: + + from psycopg2 import sql + + cur.execute( + sql.SQL("insert into {} values (%s, %s)") + .format(sql.Identifier('my_table')), + [10, 20]) + + +.. autoclass:: Composable + + .. automethod:: as_string + + +.. autoclass:: SQL + + .. autoattribute:: string + + .. automethod:: format + + .. automethod:: join + + +.. autoclass:: Identifier + + .. autoattribute:: string + +.. autoclass:: Literal + + .. autoattribute:: wrapped + +.. autoclass:: Placeholder + + .. autoattribute:: name + +.. autoclass:: Composed + + .. autoattribute:: seq + + .. automethod:: join diff --git a/doc/src/usage.rst b/doc/src/usage.rst index d9fea75..1366485 100644 --- a/doc/src/usage.rst +++ b/doc/src/usage.rst @@ -132,9 +132,10 @@ query: >>> cur.execute("INSERT INTO foo VALUES (%s)", ("bar",)) # correct >>> cur.execute("INSERT INTO foo VALUES (%s)", ["bar"]) # correct -- Only variable values should be bound via this method: it shouldn't be used - to set table or field names. For these elements, ordinary string formatting - should be used before running `~cursor.execute()`. +- Only query values should be bound via this method: it shouldn't be used to + merge table or field names to the query. If you need to generate dynamically + an SQL query (for instance choosing dynamically a table name) you can use + the facilities provided by the `psycopg2.sql` module. diff --git a/lib/sql.py b/lib/sql.py new file mode 100644 index 0000000..e7d42e6 --- /dev/null +++ b/lib/sql.py @@ -0,0 +1,424 @@ +"""SQL composition utility module +""" + +# psycopg/sql.py - Implementation of the JSON adaptation objects +# +# Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com> +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import sys +import string + +from psycopg2 import extensions as ext + + +_formatter = string.Formatter() + + +class Composable(object): + """ + Abstract base class for objects that can be used to compose an SQL string. + + `!Composable` objects can be passed directly to `~cursor.execute()` and + `~cursor.executemany()` in place of the query string. + + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. + """ + def __init__(self, wrapped): + self._wrapped = wrapped + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._wrapped) + + def as_string(self, context): + """ + Return the string value of the object. + + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` + + The method is automatically invoked by `~cursor.execute()` and + `~cursor.executemany()` if a `!Composable` is passed instead of the + query string. + """ + raise NotImplementedError + + def __add__(self, other): + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composable): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + def __mul__(self, n): + return Composed([self] * n) + + def __eq__(self, other): + return type(self) is type(other) and self._wrapped == other._wrapped + + def __ne__(self, other): + return not self.__eq__(other) + + +class Composed(Composable): + """ + A `Composable` object made of a sequence of `Composable`. + + The object is usually created using `Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of `Composable` as arguments. + + Example:: + + >>> comp = sql.Composed( + ... [sql.SQL("insert into "), sql.Identifier("table")]) + >>> print(comp.as_string(conn)) + insert into "table" + + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). + """ + def __init__(self, seq): + wrapped = [] + for i in seq: + if not isinstance(i, Composable): + raise TypeError( + "Composed elements must be Composable, got %r instead" % i) + wrapped.append(i) + + super(Composed, self).__init__(wrapped) + + @property + def seq(self): + """The list of the content of the `!Composed`.""" + return list(self._wrapped) + + def as_string(self, context): + rv = [] + for i in self._wrapped: + rv.append(i.as_string(context)) + return ''.join(rv) + + def __iter__(self): + return iter(self._wrapped) + + def __add__(self, other): + if isinstance(other, Composed): + return Composed(self._wrapped + other._wrapped) + if isinstance(other, Composable): + return Composed(self._wrapped + [other]) + else: + return NotImplemented + + def join(self, joiner): + """ + Return a new `!Composed` interposing the *joiner* with the `!Composed` items. + + The *joiner* must be a `SQL` or a string which will be interpreted as + an `SQL`. + + Example:: + + >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed + >>> print(fields.join(', ').as_string(conn)) + "foo", "bar" + + """ + if isinstance(joiner, basestring): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be a string or an SQL") + + return joiner.join(self) + + +class SQL(Composable): + """ + A `Composable` representing a snippet of SQL statement. + + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). + + The *string* doesn't undergo any form of escaping, so it is not suitable to + represent variable identifiers or values: you should only use it to pass + constant strings representing templates or snippets of SQL statements; use + other objects such as `Identifier` or `Literal` to represent variable + parts. + + Example:: + + >>> query = sql.SQL("select {0} from {1}").format( + ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), + ... sql.Identifier('table')) + >>> print(query.as_string(conn)) + select "foo", "bar" from "table" + """ + def __init__(self, string): + if not isinstance(string, basestring): + raise TypeError("SQL values must be strings") + super(SQL, self).__init__(string) + + @property + def string(self): + """The string wrapped by the `!SQL` object.""" + return self._wrapped + + def as_string(self, context): + return self._wrapped + + def format(self, *args, **kwargs): + """ + Merge `Composable` objects into a template. + + :param `Composable` args: parameters to replace to numbered + (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders + :param `Composable` kwargs: parameters to replace to named (``{name}``) + placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``, only available from Python + 2.7), numbered (``{0}``, ``{1}``...), and named placeholders + (``{name}``), with positional arguments replacing the numbered + placeholders and keywords replacing the named ones. However placeholder + modifiers (``{0!r}``, ``{0:<10}``) are not supported. Only + `!Composable` objects can be passed to the template. + + Example:: + + >>> print(sql.SQL("select * from {} where {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + >>> print(sql.SQL("select * from {tbl} where {pkey} = %s") + ... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + """ + rv = [] + autonum = 0 + for pre, name, spec, conv in _formatter.parse(self._wrapped): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual") + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic") + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + + def join(self, seq): + """ + Join a sequence of `Composable`. + + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's *string* to separate the elements in *seq*. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. + + Example:: + + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) + >>> print(snip.as_string(conn)) + "foo", "bar", "baz" + """ + rv = [] + it = iter(seq) + try: + rv.append(it.next()) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composable): + """ + A `Composable` representing an SQL identifer. + + Identifiers usually represent names of database objects, such as tables + or fields. They follow `different rules`__ than SQL string literals for + escaping (e.g. they use double quotes). + + .. __: https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html# \ + SQL-SYNTAX-IDENTIFIERS + + Example:: + + >>> t1 = sql.Identifier("foo") + >>> t2 = sql.Identifier("ba'r") + >>> t3 = sql.Identifier('ba"z') + >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn)) + "foo", "ba'r", "ba""z" + + """ + def __init__(self, string): + if not isinstance(string, basestring): + raise TypeError("SQL identifiers must be strings") + + super(Identifier, self).__init__(string) + + @property + def string(self): + """The string wrapped by the `Identifier`.""" + return self._wrapped + + def as_string(self, context): + return ext.quote_ident(self._wrapped, context) + + +class Literal(Composable): + """ + A `Composable` representing an SQL value to include in a query. + + Usually you will want to include placeholders in the query and pass values + as `~cursor.execute()` arguments. If however you really really need to + include a literal value in the query you can use this object. + + The string returned by `!as_string()` follows the normal :ref:`adaptation + rules <python-types-adaptation>` for Python objects. + + Example:: + + >>> s1 = sql.Literal("foo") + >>> s2 = sql.Literal("ba'r") + >>> s3 = sql.Literal(42) + >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) + 'foo', 'ba''r', 42 + + """ + @property + def wrapped(self): + """The object wrapped by the `!Literal`.""" + return self._wrapped + + def as_string(self, context): + # is it a connection or cursor? + if isinstance(context, ext.connection): + conn = context + elif isinstance(context, ext.cursor): + conn = context.connection + else: + raise TypeError("context must be a connection or a cursor") + + a = ext.adapt(self._wrapped) + if hasattr(a, 'prepare'): + a.prepare(conn) + + rv = a.getquoted() + if sys.version_info[0] >= 3 and isinstance(rv, bytes): + rv = rv.decode(ext.encodings[conn.encoding]) + + return rv + + +class Placeholder(Composable): + """A `Composable` representing a placeholder for query parameters. + + If the name is specified, generate a named placeholder (e.g. ``%(name)s``), + otherwise generate a positional placeholder (e.g. ``%s``). + + The object is useful to generate SQL queries with a variable number of + arguments. + + Examples:: + + >>> names = ['foo', 'bar', 'baz'] + + >>> q1 = sql.SQL("insert into table ({}) values ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) + >>> print(q1.as_string(conn)) + insert into table ("foo", "bar", "baz") values (%s, %s, %s) + + >>> q2 = sql.SQL("insert into table ({}) values ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) + >>> print(q2.as_string(conn)) + insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s) + + """ + + def __init__(self, name=None): + if isinstance(name, basestring): + if ')' in name: + raise ValueError("invalid name: %r" % name) + + elif name is not None: + raise TypeError("expected string or None as name, got %r" % name) + + super(Placeholder, self).__init__(name) + + @property + def name(self): + """The name of the `!Placeholder`.""" + return self._wrapped + + def __repr__(self): + return "Placeholder(%r)" % ( + self._wrapped if self._wrapped is not None else '',) + + def as_string(self, context): + if self._wrapped is not None: + return "%%(%s)s" % self._wrapped + else: + return "%s" + + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index a7303c6..5031033 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -267,10 +267,35 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new) return 0; } +/* Return 1 if `obj` is a `psycopg2.sql.Composable` instance, else 0 + * Set an exception and return -1 in case of error. + */ +RAISES_NEG static int +_curs_is_composible(PyObject *obj) +{ + int rv = -1; + PyObject *m = NULL; + PyObject *comp = NULL; + + if (!(m = PyImport_ImportModule("psycopg2.sql"))) { goto exit; } + if (!(comp = PyObject_GetAttrString(m, "Composable"))) { goto exit; } + rv = PyObject_IsInstance(obj, comp); + +exit: + Py_XDECREF(comp); + Py_XDECREF(m); + return rv; + +} + static PyObject *_psyco_curs_validate_sql_basic( cursorObject *self, PyObject *sql ) { + PyObject *rv = NULL; + PyObject *comp = NULL; + int iscomp; + /* Performs very basic validation on an incoming SQL string. Returns a new reference to a str instance on success; NULL on failure, after having set an exception. */ @@ -278,26 +303,48 @@ static PyObject *_psyco_curs_validate_sql_basic( if (!sql || !PyObject_IsTrue(sql)) { psyco_set_error(ProgrammingError, self, "can't execute an empty query"); - goto fail; + goto exit; } if (Bytes_Check(sql)) { /* Necessary for ref-count symmetry with the unicode case: */ Py_INCREF(sql); + rv = sql; } else if (PyUnicode_Check(sql)) { - if (!(sql = conn_encode(self->conn, sql))) { goto fail; } + if (!(rv = conn_encode(self->conn, sql))) { goto exit; } + } + else if (0 != (iscomp = _curs_is_composible(sql))) { + if (iscomp < 0) { goto exit; } + if (!(comp = PyObject_CallMethod(sql, "as_string", "O", self->conn))) { + goto exit; + } + + if (Bytes_Check(comp)) { + rv = comp; + comp = NULL; + } + else if (PyUnicode_Check(comp)) { + if (!(rv = conn_encode(self->conn, comp))) { goto exit; } + } + else { + PyErr_Format(PyExc_TypeError, + "as_string() should return a string: got %s instead", + Py_TYPE(comp)->tp_name); + goto exit; + } } else { /* the is not unicode or string, raise an error */ - PyErr_SetString(PyExc_TypeError, - "argument 1 must be a string or unicode object"); - goto fail; + PyErr_Format(PyExc_TypeError, + "argument 1 must be a string or unicode object: got %s instead", + Py_TYPE(sql)->tp_name); + goto exit; } - return sql; /* new reference */ - fail: - return NULL; +exit: + Py_XDECREF(comp); + return rv; } /* Merge together a query string and its arguments. diff --git a/tests/__init__.py b/tests/__init__.py index fd12d07..85a4ec9 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -36,7 +36,6 @@ import test_bugX000 import test_bug_gc import test_cancel import test_connection -import test_replication import test_copy import test_cursor import test_dates @@ -50,6 +49,8 @@ import test_module import test_notify import test_psycopg2_dbapi20 import test_quote +import test_replication +import test_sql import test_transaction import test_types_basic import test_types_extras @@ -79,7 +80,6 @@ def test_suite(): suite.addTest(test_bug_gc.test_suite()) suite.addTest(test_cancel.test_suite()) suite.addTest(test_connection.test_suite()) - suite.addTest(test_replication.test_suite()) suite.addTest(test_copy.test_suite()) suite.addTest(test_cursor.test_suite()) suite.addTest(test_dates.test_suite()) @@ -93,11 +93,14 @@ def test_suite(): suite.addTest(test_notify.test_suite()) suite.addTest(test_psycopg2_dbapi20.test_suite()) suite.addTest(test_quote.test_suite()) + suite.addTest(test_replication.test_suite()) + suite.addTest(test_sql.test_suite()) suite.addTest(test_transaction.test_suite()) suite.addTest(test_types_basic.test_suite()) suite.addTest(test_types_extras.test_suite()) suite.addTest(test_with.test_suite()) return suite + if __name__ == '__main__': unittest.main(defaultTest='test_suite') diff --git a/tests/test_ipaddress.py b/tests/test_ipaddress.py index 97eabba..49413f4 100755 --- a/tests/test_ipaddress.py +++ b/tests/test_ipaddress.py @@ -1,5 +1,7 @@ #!/usr/bin/env python -# # test_ipaddress.py - tests for ipaddress support # +# +# test_ipaddress.py - tests for ipaddress support +# # Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com> # # psycopg2 is free software: you can redistribute it and/or modify it diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100755 index 0000000..ffb4f1f --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python + +# test_sql.py - tests for the psycopg2.sql module +# +# Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com> +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import datetime as dt +from testutils import unittest, ConnectingTestCase, skip_before_python + +import psycopg2 +from psycopg2 import sql + + +class SqlFormatTests(ConnectingTestCase): + @skip_before_python(2, 7) + def test_pos(self): + s = sql.SQL("select {} from {}").format( + sql.Identifier('field'), sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_pos_spec(self): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier('field'), sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier('table'), sql.Identifier('field')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_dict(self): + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier('field'), t=sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_unicode(self): + s = sql.SQL(u"select {0} from {1}").format( + sql.Identifier(u'field'), sql.Identifier('table')) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, unicode)) + self.assertEqual(s1, u'select "field" from "table"') + + def test_compose_literal(self): + s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) + s1 = s.as_string(self.conn) + self.assertEqual(s1, "select '2016-12-31'::date;") + + def test_compose_empty(self): + s = sql.SQL("select foo;").format() + s1 = s.as_string(self.conn) + self.assertEqual(s1, "select foo;") + + def test_percent_escape(self): + s = sql.SQL("42 % {0}").format(sql.Literal(7)) + s1 = s.as_string(self.conn) + self.assertEqual(s1, "42 % 7") + + def test_braces_escape(self): + s = sql.SQL("{{{0}}}").format(sql.Literal(7)) + self.assertEqual(s.as_string(self.conn), "{7}") + s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) + self.assertEqual(s.as_string(self.conn), "{1,7}") + + def test_compose_badnargs(self): + self.assertRaises(IndexError, sql.SQL("select {0};").format) + + @skip_before_python(2, 7) + def test_compose_badnargs_auto(self): + self.assertRaises(IndexError, sql.SQL("select {};").format) + self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20) + self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20) + + def test_compose_bad_args_type(self): + self.assertRaises(IndexError, sql.SQL("select {0};").format, a=10) + self.assertRaises(KeyError, sql.SQL("select {x};").format, 10) + + def test_must_be_composable(self): + self.assertRaises(TypeError, sql.SQL("select {0};").format, 'foo') + self.assertRaises(TypeError, sql.SQL("select {0};").format, 10) + + def test_no_modifiers(self): + self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10) + self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10) + + def test_must_be_adaptable(self): + class Foo(object): + pass + + self.assertRaises(psycopg2.ProgrammingError, + sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn) + + def test_execute(self): + cur = self.conn.cursor() + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + (10, 'a', 'b', 'c')) + + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) + + def test_executemany(self): + cur = self.conn.cursor() + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.Placeholder() * 3).join(', ')), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + + +class IdentifierTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Identifier, sql.Composable)) + + def test_init(self): + self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier)) + self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier)) + self.assertRaises(TypeError, sql.Identifier, 10) + self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) + + def test_string(self): + self.assertEqual(sql.Identifier('foo').string, 'foo') + + def test_repr(self): + obj = sql.Identifier("fo'o") + self.assertEqual(repr(obj), 'Identifier("fo\'o")') + self.assertEqual(repr(obj), str(obj)) + + def test_eq(self): + self.assert_(sql.Identifier('foo') == sql.Identifier('foo')) + self.assert_(sql.Identifier('foo') != sql.Identifier('bar')) + self.assert_(sql.Identifier('foo') != 'foo') + self.assert_(sql.Identifier('foo') != sql.SQL('foo')) + + def test_as_str(self): + self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"') + self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"') + + def test_join(self): + self.assert_(not hasattr(sql.Identifier('foo'), 'join')) + + +class LiteralTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Literal, sql.Composable)) + + def test_init(self): + self.assert_(isinstance(sql.Literal('foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(u'foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(42), sql.Literal)) + self.assert_(isinstance( + sql.Literal(dt.date(2016, 12, 31)), sql.Literal)) + + def test_wrapped(self): + self.assertEqual(sql.Literal('foo').wrapped, 'foo') + + def test_repr(self): + self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')") + self.assertEqual(str(sql.Literal("foo")), "Literal('foo')") + self.assertEqual( + sql.Literal("foo").as_string(self.conn).replace("E'", "'"), + "'foo'") + self.assertEqual(sql.Literal(42).as_string(self.conn), "42") + self.assertEqual( + sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), + "'2017-01-01'::date") + + def test_eq(self): + self.assert_(sql.Literal('foo') == sql.Literal('foo')) + self.assert_(sql.Literal('foo') != sql.Literal('bar')) + self.assert_(sql.Literal('foo') != 'foo') + self.assert_(sql.Literal('foo') != sql.SQL('foo')) + + def test_must_be_adaptable(self): + class Foo(object): + pass + + self.assertRaises(psycopg2.ProgrammingError, + sql.Literal(Foo()).as_string, self.conn) + + +class SQLTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.SQL, sql.Composable)) + + def test_init(self): + self.assert_(isinstance(sql.SQL('foo'), sql.SQL)) + self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL)) + self.assertRaises(TypeError, sql.SQL, 10) + self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31)) + + def test_string(self): + self.assertEqual(sql.SQL('foo').string, 'foo') + + def test_repr(self): + self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')") + self.assertEqual(str(sql.SQL("foo")), "SQL('foo')") + self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo") + + def test_eq(self): + self.assert_(sql.SQL('foo') == sql.SQL('foo')) + self.assert_(sql.SQL('foo') != sql.SQL('bar')) + self.assert_(sql.SQL('foo') != 'foo') + self.assert_(sql.SQL('foo') != sql.Literal('foo')) + + def test_sum(self): + obj = sql.SQL("foo") + sql.SQL("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foobar") + + def test_sum_inplace(self): + obj = sql.SQL("foo") + obj += sql.SQL("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foobar") + + def test_multiply(self): + obj = sql.SQL("foo") * 3 + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foofoofoo") + + def test_join(self): + obj = sql.SQL(", ").join( + [sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)]) + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42') + + obj = sql.SQL(", ").join( + sql.Composed([sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)])) + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42') + + obj = sql.SQL(", ").join([]) + self.assertEqual(obj, sql.Composed([])) + + +class ComposedTest(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Composed, sql.Composable)) + + def test_repr(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + self.assertEqual(repr(obj), + """Composed([Literal('foo'), Identifier("b'ar")])""") + self.assertEqual(str(obj), repr(obj)) + + def test_seq(self): + l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')] + self.assertEqual(sql.Composed(l).seq, l) + + def test_eq(self): + l = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + self.assert_(sql.Composed(l) == sql.Composed(list(l))) + self.assert_(sql.Composed(l) != l) + self.assert_(sql.Composed(l) != sql.Composed(l2)) + + def test_join(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "'foo', \"b'ar\"") + + def test_sum(self): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + def test_sum_inplace(self): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')]) + it = iter(obj) + i = it.next() + self.assertEqual(i, sql.SQL('foo')) + i = it.next() + self.assertEqual(i, sql.SQL('bar')) + self.assertRaises(StopIteration, it.next) + + +class PlaceholderTest(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Placeholder, sql.Composable)) + + def test_name(self): + self.assertEqual(sql.Placeholder().name, None) + self.assertEqual(sql.Placeholder('foo').name, 'foo') + + def test_repr(self): + self.assert_(str(sql.Placeholder()), 'Placeholder()') + self.assert_(repr(sql.Placeholder()), 'Placeholder()') + self.assert_(sql.Placeholder().as_string(self.conn), '%s') + + def test_repr_name(self): + self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')") + self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')") + self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s') + + def test_bad_name(self): + self.assertRaises(ValueError, sql.Placeholder, ')') + + def test_eq(self): + self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo')) + self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar')) + self.assert_(sql.Placeholder('foo') != 'foo') + self.assert_(sql.Placeholder() == sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Literal('foo')) + + +class ValuesTest(ConnectingTestCase): + def test_null(self): + self.assert_(isinstance(sql.NULL, sql.SQL)) + self.assertEqual(sql.NULL.as_string(self.conn), "NULL") + + def test_default(self): + self.assert_(isinstance(sql.DEFAULT, sql.SQL)) + self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT") + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main() |