summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NEWS1
-rw-r--r--doc/src/faq.rst18
-rw-r--r--doc/src/index.rst1
-rw-r--r--doc/src/sql.rst89
-rw-r--r--doc/src/usage.rst7
-rw-r--r--lib/sql.py424
-rw-r--r--psycopg/cursor_type.c63
-rwxr-xr-xtests/__init__.py7
-rwxr-xr-xtests/test_ipaddress.py4
-rwxr-xr-xtests/test_sql.py378
10 files changed, 978 insertions, 14 deletions
diff --git a/NEWS b/NEWS
index c9cbb4c..499656d 100644
--- a/NEWS
+++ b/NEWS
@@ -6,6 +6,7 @@ What's new in psycopg 2.7
New features:
+- Added `~psycopg2.sql` module to generate SQL dynamically (:ticket:`#308`).
- Added :ref:`replication-support` (:ticket:`#322`). Main authors are
Oleksandr Shulgin and Craig Ringer, who deserve a huge thank you.
- Added `~psycopg2.extensions.parse_dsn()` and
diff --git a/doc/src/faq.rst b/doc/src/faq.rst
index 89d8a63..fb7b33d 100644
--- a/doc/src/faq.rst
+++ b/doc/src/faq.rst
@@ -151,6 +151,24 @@ Psycopg converts :sql:`json` values into Python objects but :sql:`jsonb` values
See :ref:`adapt-json` for further details.
+.. _faq-identifier:
+.. cssclass:: faq
+
+How can I pass field/table names to a query?
+ The arguments in the `~cursor.execute()` methods can only represent data
+ to pass to the query: they cannot represent a table or field name::
+
+ # This doesn't work
+ cur.execute("insert into %s values (%s)", ["my_table", 42])
+
+ If you want to build a query dynamically you can use the objects exposed
+ by the `psycopg2.sql` module::
+
+ cur.execute(
+ sql.SQL("insert into %s values (%%s)") % [sql.Identifier("my_table")],
+ [42])
+
+
.. _faq-bytea-9.0:
.. cssclass:: faq
diff --git a/doc/src/index.rst b/doc/src/index.rst
index 5cf0f24..96a1423 100644
--- a/doc/src/index.rst
+++ b/doc/src/index.rst
@@ -44,6 +44,7 @@ Psycopg 2 is both Unicode and Python 3 friendly.
advanced
extensions
extras
+ sql
tz
pool
errorcodes
diff --git a/doc/src/sql.rst b/doc/src/sql.rst
new file mode 100644
index 0000000..0aee451
--- /dev/null
+++ b/doc/src/sql.rst
@@ -0,0 +1,89 @@
+`psycopg2.sql` -- SQL string composition
+========================================
+
+.. sectionauthor:: Daniele Varrazzo <daniele.varrazzo@gmail.com>
+
+.. module:: psycopg2.sql
+
+.. versionadded:: 2.7
+
+The module contains objects and functions useful to generate SQL dynamically,
+in a convenient and safe way. SQL identifiers (e.g. names of tables and
+fields) cannot be passed to the `~cursor.execute()` method like query
+arguments::
+
+ # This will not work
+ table_name = 'my_table'
+ cur.execute("insert into %s values (%s, %s)", [table_name, 10, 20])
+
+The SQL query should be composed before the arguments are merged, for
+instance::
+
+ # This works, but it is not optimal
+ table_name = 'my_table'
+ cur.execute(
+ "insert into %s values (%%s, %%s)" % table_name,
+ [10, 20])
+
+This sort of works, but it is an accident waiting to happen: the table name
+may be an invalid SQL literal and need quoting; even more serious is the
+security problem in case the table name comes from an untrusted source. The
+name should be escaped using `~psycopg2.extensions.quote_ident()`::
+
+ # This works, but it is not optimal
+ table_name = 'my_table'
+ cur.execute(
+ "insert into %s values (%%s, %%s)" % ext.quote_ident(table_name),
+ [10, 20])
+
+This is now safe, but it somewhat ad-hoc. In case, for some reason, it is
+necessary to include a value in the query string (as opposite as in a value)
+the merging rule is still different (`~psycopg2.extensions.adapt()` should be
+used...). It is also still relatively dangerous: if `!quote_ident()` is
+forgotten somewhere, the program will usually work, but will eventually crash
+in the presence of a table or field name with containing characters to escape,
+or will present a potentially exploitable weakness.
+
+The objects exposed by the `!psycopg2.sql` module allow generating SQL
+statements on the fly, separating clearly the variable parts of the statement
+from the query parameters::
+
+ from psycopg2 import sql
+
+ cur.execute(
+ sql.SQL("insert into {} values (%s, %s)")
+ .format(sql.Identifier('my_table')),
+ [10, 20])
+
+
+.. autoclass:: Composable
+
+ .. automethod:: as_string
+
+
+.. autoclass:: SQL
+
+ .. autoattribute:: string
+
+ .. automethod:: format
+
+ .. automethod:: join
+
+
+.. autoclass:: Identifier
+
+ .. autoattribute:: string
+
+.. autoclass:: Literal
+
+ .. autoattribute:: wrapped
+
+.. autoclass:: Placeholder
+
+ .. autoattribute:: name
+
+.. autoclass:: Composed
+
+ .. autoattribute:: seq
+
+ .. automethod:: join
diff --git a/doc/src/usage.rst b/doc/src/usage.rst
index d9fea75..1366485 100644
--- a/doc/src/usage.rst
+++ b/doc/src/usage.rst
@@ -132,9 +132,10 @@ query:
>>> cur.execute("INSERT INTO foo VALUES (%s)", ("bar",)) # correct
>>> cur.execute("INSERT INTO foo VALUES (%s)", ["bar"]) # correct
-- Only variable values should be bound via this method: it shouldn't be used
- to set table or field names. For these elements, ordinary string formatting
- should be used before running `~cursor.execute()`.
+- Only query values should be bound via this method: it shouldn't be used to
+ merge table or field names to the query. If you need to generate dynamically
+ an SQL query (for instance choosing dynamically a table name) you can use
+ the facilities provided by the `psycopg2.sql` module.
diff --git a/lib/sql.py b/lib/sql.py
new file mode 100644
index 0000000..e7d42e6
--- /dev/null
+++ b/lib/sql.py
@@ -0,0 +1,424 @@
+"""SQL composition utility module
+"""
+
+# psycopg/sql.py - Implementation of the JSON adaptation objects
+#
+# Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com>
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# In addition, as a special exception, the copyright holders give
+# permission to link this program with the OpenSSL library (or with
+# modified versions of OpenSSL that use the same license as OpenSSL),
+# and distribute linked combinations including the two.
+#
+# You must obey the GNU Lesser General Public License in all respects for
+# all of the code used other than OpenSSL.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+import sys
+import string
+
+from psycopg2 import extensions as ext
+
+
+_formatter = string.Formatter()
+
+
+class Composable(object):
+ """
+ Abstract base class for objects that can be used to compose an SQL string.
+
+ `!Composable` objects can be passed directly to `~cursor.execute()` and
+ `~cursor.executemany()` in place of the query string.
+
+ `!Composable` objects can be joined using the ``+`` operator: the result
+ will be a `Composed` instance containing the objects joined. The operator
+ ``*`` is also supported with an integer argument: the result is a
+ `!Composed` instance containing the left argument repeated as many times as
+ requested.
+ """
+ def __init__(self, wrapped):
+ self._wrapped = wrapped
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self._wrapped)
+
+ def as_string(self, context):
+ """
+ Return the string value of the object.
+
+ :param context: the context to evaluate the string into.
+ :type context: `connection` or `cursor`
+
+ The method is automatically invoked by `~cursor.execute()` and
+ `~cursor.executemany()` if a `!Composable` is passed instead of the
+ query string.
+ """
+ raise NotImplementedError
+
+ def __add__(self, other):
+ if isinstance(other, Composed):
+ return Composed([self]) + other
+ if isinstance(other, Composable):
+ return Composed([self]) + Composed([other])
+ else:
+ return NotImplemented
+
+ def __mul__(self, n):
+ return Composed([self] * n)
+
+ def __eq__(self, other):
+ return type(self) is type(other) and self._wrapped == other._wrapped
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class Composed(Composable):
+ """
+ A `Composable` object made of a sequence of `Composable`.
+
+ The object is usually created using `Composable` operators and methods.
+ However it is possible to create a `!Composed` directly specifying a
+ sequence of `Composable` as arguments.
+
+ Example::
+
+ >>> comp = sql.Composed(
+ ... [sql.SQL("insert into "), sql.Identifier("table")])
+ >>> print(comp.as_string(conn))
+ insert into "table"
+
+ `!Composed` objects are iterable (so they can be used in `SQL.join` for
+ instance).
+ """
+ def __init__(self, seq):
+ wrapped = []
+ for i in seq:
+ if not isinstance(i, Composable):
+ raise TypeError(
+ "Composed elements must be Composable, got %r instead" % i)
+ wrapped.append(i)
+
+ super(Composed, self).__init__(wrapped)
+
+ @property
+ def seq(self):
+ """The list of the content of the `!Composed`."""
+ return list(self._wrapped)
+
+ def as_string(self, context):
+ rv = []
+ for i in self._wrapped:
+ rv.append(i.as_string(context))
+ return ''.join(rv)
+
+ def __iter__(self):
+ return iter(self._wrapped)
+
+ def __add__(self, other):
+ if isinstance(other, Composed):
+ return Composed(self._wrapped + other._wrapped)
+ if isinstance(other, Composable):
+ return Composed(self._wrapped + [other])
+ else:
+ return NotImplemented
+
+ def join(self, joiner):
+ """
+ Return a new `!Composed` interposing the *joiner* with the `!Composed` items.
+
+ The *joiner* must be a `SQL` or a string which will be interpreted as
+ an `SQL`.
+
+ Example::
+
+ >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
+ >>> print(fields.join(', ').as_string(conn))
+ "foo", "bar"
+
+ """
+ if isinstance(joiner, basestring):
+ joiner = SQL(joiner)
+ elif not isinstance(joiner, SQL):
+ raise TypeError(
+ "Composed.join() argument must be a string or an SQL")
+
+ return joiner.join(self)
+
+
+class SQL(Composable):
+ """
+ A `Composable` representing a snippet of SQL statement.
+
+ `!SQL` exposes `join()` and `format()` methods useful to create a template
+ where to merge variable parts of a query (for instance field or table
+ names).
+
+ The *string* doesn't undergo any form of escaping, so it is not suitable to
+ represent variable identifiers or values: you should only use it to pass
+ constant strings representing templates or snippets of SQL statements; use
+ other objects such as `Identifier` or `Literal` to represent variable
+ parts.
+
+ Example::
+
+ >>> query = sql.SQL("select {0} from {1}").format(
+ ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
+ ... sql.Identifier('table'))
+ >>> print(query.as_string(conn))
+ select "foo", "bar" from "table"
+ """
+ def __init__(self, string):
+ if not isinstance(string, basestring):
+ raise TypeError("SQL values must be strings")
+ super(SQL, self).__init__(string)
+
+ @property
+ def string(self):
+ """The string wrapped by the `!SQL` object."""
+ return self._wrapped
+
+ def as_string(self, context):
+ return self._wrapped
+
+ def format(self, *args, **kwargs):
+ """
+ Merge `Composable` objects into a template.
+
+ :param `Composable` args: parameters to replace to numbered
+ (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
+ :param `Composable` kwargs: parameters to replace to named (``{name}``)
+ placeholders
+ :return: the union of the `!SQL` string with placeholders replaced
+ :rtype: `Composed`
+
+ The method is similar to the Python `str.format()` method: the string
+ template supports auto-numbered (``{}``, only available from Python
+ 2.7), numbered (``{0}``, ``{1}``...), and named placeholders
+ (``{name}``), with positional arguments replacing the numbered
+ placeholders and keywords replacing the named ones. However placeholder
+ modifiers (``{0!r}``, ``{0:<10}``) are not supported. Only
+ `!Composable` objects can be passed to the template.
+
+ Example::
+
+ >>> print(sql.SQL("select * from {} where {} = %s")
+ ... .format(sql.Identifier('people'), sql.Identifier('id'))
+ ... .as_string(conn))
+ select * from "people" where "id" = %s
+
+ >>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
+ ... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
+ ... .as_string(conn))
+ select * from "people" where "id" = %s
+
+ """
+ rv = []
+ autonum = 0
+ for pre, name, spec, conv in _formatter.parse(self._wrapped):
+ if spec:
+ raise ValueError("no format specification supported by SQL")
+ if conv:
+ raise ValueError("no format conversion supported by SQL")
+ if pre:
+ rv.append(SQL(pre))
+
+ if name is None:
+ continue
+
+ if name.isdigit():
+ if autonum:
+ raise ValueError(
+ "cannot switch from automatic field numbering to manual")
+ rv.append(args[int(name)])
+ autonum = None
+
+ elif not name:
+ if autonum is None:
+ raise ValueError(
+ "cannot switch from manual field numbering to automatic")
+ rv.append(args[autonum])
+ autonum += 1
+
+ else:
+ rv.append(kwargs[name])
+
+ return Composed(rv)
+
+ def join(self, seq):
+ """
+ Join a sequence of `Composable`.
+
+ :param seq: the elements to join.
+ :type seq: iterable of `!Composable`
+
+ Use the `!SQL` object's *string* to separate the elements in *seq*.
+ Note that `Composed` objects are iterable too, so they can be used as
+ argument for this method.
+
+ Example::
+
+ >>> snip = sql.SQL(', ').join(
+ ... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
+ >>> print(snip.as_string(conn))
+ "foo", "bar", "baz"
+ """
+ rv = []
+ it = iter(seq)
+ try:
+ rv.append(it.next())
+ except StopIteration:
+ pass
+ else:
+ for i in it:
+ rv.append(self)
+ rv.append(i)
+
+ return Composed(rv)
+
+
+class Identifier(Composable):
+ """
+ A `Composable` representing an SQL identifer.
+
+ Identifiers usually represent names of database objects, such as tables
+ or fields. They follow `different rules`__ than SQL string literals for
+ escaping (e.g. they use double quotes).
+
+ .. __: https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html# \
+ SQL-SYNTAX-IDENTIFIERS
+
+ Example::
+
+ >>> t1 = sql.Identifier("foo")
+ >>> t2 = sql.Identifier("ba'r")
+ >>> t3 = sql.Identifier('ba"z')
+ >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
+ "foo", "ba'r", "ba""z"
+
+ """
+ def __init__(self, string):
+ if not isinstance(string, basestring):
+ raise TypeError("SQL identifiers must be strings")
+
+ super(Identifier, self).__init__(string)
+
+ @property
+ def string(self):
+ """The string wrapped by the `Identifier`."""
+ return self._wrapped
+
+ def as_string(self, context):
+ return ext.quote_ident(self._wrapped, context)
+
+
+class Literal(Composable):
+ """
+ A `Composable` representing an SQL value to include in a query.
+
+ Usually you will want to include placeholders in the query and pass values
+ as `~cursor.execute()` arguments. If however you really really need to
+ include a literal value in the query you can use this object.
+
+ The string returned by `!as_string()` follows the normal :ref:`adaptation
+ rules <python-types-adaptation>` for Python objects.
+
+ Example::
+
+ >>> s1 = sql.Literal("foo")
+ >>> s2 = sql.Literal("ba'r")
+ >>> s3 = sql.Literal(42)
+ >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
+ 'foo', 'ba''r', 42
+
+ """
+ @property
+ def wrapped(self):
+ """The object wrapped by the `!Literal`."""
+ return self._wrapped
+
+ def as_string(self, context):
+ # is it a connection or cursor?
+ if isinstance(context, ext.connection):
+ conn = context
+ elif isinstance(context, ext.cursor):
+ conn = context.connection
+ else:
+ raise TypeError("context must be a connection or a cursor")
+
+ a = ext.adapt(self._wrapped)
+ if hasattr(a, 'prepare'):
+ a.prepare(conn)
+
+ rv = a.getquoted()
+ if sys.version_info[0] >= 3 and isinstance(rv, bytes):
+ rv = rv.decode(ext.encodings[conn.encoding])
+
+ return rv
+
+
+class Placeholder(Composable):
+ """A `Composable` representing a placeholder for query parameters.
+
+ If the name is specified, generate a named placeholder (e.g. ``%(name)s``),
+ otherwise generate a positional placeholder (e.g. ``%s``).
+
+ The object is useful to generate SQL queries with a variable number of
+ arguments.
+
+ Examples::
+
+ >>> names = ['foo', 'bar', 'baz']
+
+ >>> q1 = sql.SQL("insert into table ({}) values ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(sql.Placeholder() * len(names)))
+ >>> print(q1.as_string(conn))
+ insert into table ("foo", "bar", "baz") values (%s, %s, %s)
+
+ >>> q2 = sql.SQL("insert into table ({}) values ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(map(sql.Placeholder, names)))
+ >>> print(q2.as_string(conn))
+ insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s)
+
+ """
+
+ def __init__(self, name=None):
+ if isinstance(name, basestring):
+ if ')' in name:
+ raise ValueError("invalid name: %r" % name)
+
+ elif name is not None:
+ raise TypeError("expected string or None as name, got %r" % name)
+
+ super(Placeholder, self).__init__(name)
+
+ @property
+ def name(self):
+ """The name of the `!Placeholder`."""
+ return self._wrapped
+
+ def __repr__(self):
+ return "Placeholder(%r)" % (
+ self._wrapped if self._wrapped is not None else '',)
+
+ def as_string(self, context):
+ if self._wrapped is not None:
+ return "%%(%s)s" % self._wrapped
+ else:
+ return "%s"
+
+
+# Literals
+NULL = SQL("NULL")
+DEFAULT = SQL("DEFAULT")
diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c
index a7303c6..5031033 100644
--- a/psycopg/cursor_type.c
+++ b/psycopg/cursor_type.c
@@ -267,10 +267,35 @@ _mogrify(PyObject *var, PyObject *fmt, cursorObject *curs, PyObject **new)
return 0;
}
+/* Return 1 if `obj` is a `psycopg2.sql.Composable` instance, else 0
+ * Set an exception and return -1 in case of error.
+ */
+RAISES_NEG static int
+_curs_is_composible(PyObject *obj)
+{
+ int rv = -1;
+ PyObject *m = NULL;
+ PyObject *comp = NULL;
+
+ if (!(m = PyImport_ImportModule("psycopg2.sql"))) { goto exit; }
+ if (!(comp = PyObject_GetAttrString(m, "Composable"))) { goto exit; }
+ rv = PyObject_IsInstance(obj, comp);
+
+exit:
+ Py_XDECREF(comp);
+ Py_XDECREF(m);
+ return rv;
+
+}
+
static PyObject *_psyco_curs_validate_sql_basic(
cursorObject *self, PyObject *sql
)
{
+ PyObject *rv = NULL;
+ PyObject *comp = NULL;
+ int iscomp;
+
/* Performs very basic validation on an incoming SQL string.
Returns a new reference to a str instance on success; NULL on failure,
after having set an exception. */
@@ -278,26 +303,48 @@ static PyObject *_psyco_curs_validate_sql_basic(
if (!sql || !PyObject_IsTrue(sql)) {
psyco_set_error(ProgrammingError, self,
"can't execute an empty query");
- goto fail;
+ goto exit;
}
if (Bytes_Check(sql)) {
/* Necessary for ref-count symmetry with the unicode case: */
Py_INCREF(sql);
+ rv = sql;
}
else if (PyUnicode_Check(sql)) {
- if (!(sql = conn_encode(self->conn, sql))) { goto fail; }
+ if (!(rv = conn_encode(self->conn, sql))) { goto exit; }
+ }
+ else if (0 != (iscomp = _curs_is_composible(sql))) {
+ if (iscomp < 0) { goto exit; }
+ if (!(comp = PyObject_CallMethod(sql, "as_string", "O", self->conn))) {
+ goto exit;
+ }
+
+ if (Bytes_Check(comp)) {
+ rv = comp;
+ comp = NULL;
+ }
+ else if (PyUnicode_Check(comp)) {
+ if (!(rv = conn_encode(self->conn, comp))) { goto exit; }
+ }
+ else {
+ PyErr_Format(PyExc_TypeError,
+ "as_string() should return a string: got %s instead",
+ Py_TYPE(comp)->tp_name);
+ goto exit;
+ }
}
else {
/* the is not unicode or string, raise an error */
- PyErr_SetString(PyExc_TypeError,
- "argument 1 must be a string or unicode object");
- goto fail;
+ PyErr_Format(PyExc_TypeError,
+ "argument 1 must be a string or unicode object: got %s instead",
+ Py_TYPE(sql)->tp_name);
+ goto exit;
}
- return sql; /* new reference */
- fail:
- return NULL;
+exit:
+ Py_XDECREF(comp);
+ return rv;
}
/* Merge together a query string and its arguments.
diff --git a/tests/__init__.py b/tests/__init__.py
index fd12d07..85a4ec9 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -36,7 +36,6 @@ import test_bugX000
import test_bug_gc
import test_cancel
import test_connection
-import test_replication
import test_copy
import test_cursor
import test_dates
@@ -50,6 +49,8 @@ import test_module
import test_notify
import test_psycopg2_dbapi20
import test_quote
+import test_replication
+import test_sql
import test_transaction
import test_types_basic
import test_types_extras
@@ -79,7 +80,6 @@ def test_suite():
suite.addTest(test_bug_gc.test_suite())
suite.addTest(test_cancel.test_suite())
suite.addTest(test_connection.test_suite())
- suite.addTest(test_replication.test_suite())
suite.addTest(test_copy.test_suite())
suite.addTest(test_cursor.test_suite())
suite.addTest(test_dates.test_suite())
@@ -93,11 +93,14 @@ def test_suite():
suite.addTest(test_notify.test_suite())
suite.addTest(test_psycopg2_dbapi20.test_suite())
suite.addTest(test_quote.test_suite())
+ suite.addTest(test_replication.test_suite())
+ suite.addTest(test_sql.test_suite())
suite.addTest(test_transaction.test_suite())
suite.addTest(test_types_basic.test_suite())
suite.addTest(test_types_extras.test_suite())
suite.addTest(test_with.test_suite())
return suite
+
if __name__ == '__main__':
unittest.main(defaultTest='test_suite')
diff --git a/tests/test_ipaddress.py b/tests/test_ipaddress.py
index 97eabba..49413f4 100755
--- a/tests/test_ipaddress.py
+++ b/tests/test_ipaddress.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python
-# # test_ipaddress.py - tests for ipaddress support #
+#
+# test_ipaddress.py - tests for ipaddress support
+#
# Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com>
#
# psycopg2 is free software: you can redistribute it and/or modify it
diff --git a/tests/test_sql.py b/tests/test_sql.py
new file mode 100755
index 0000000..ffb4f1f
--- /dev/null
+++ b/tests/test_sql.py
@@ -0,0 +1,378 @@
+#!/usr/bin/env python
+
+# test_sql.py - tests for the psycopg2.sql module
+#
+# Copyright (C) 2016 Daniele Varrazzo <daniele.varrazzo@gmail.com>
+#
+# psycopg2 is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# In addition, as a special exception, the copyright holders give
+# permission to link this program with the OpenSSL library (or with
+# modified versions of OpenSSL that use the same license as OpenSSL),
+# and distribute linked combinations including the two.
+#
+# You must obey the GNU Lesser General Public License in all respects for
+# all of the code used other than OpenSSL.
+#
+# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
+# License for more details.
+
+import datetime as dt
+from testutils import unittest, ConnectingTestCase, skip_before_python
+
+import psycopg2
+from psycopg2 import sql
+
+
+class SqlFormatTests(ConnectingTestCase):
+ @skip_before_python(2, 7)
+ def test_pos(self):
+ s = sql.SQL("select {} from {}").format(
+ sql.Identifier('field'), sql.Identifier('table'))
+ s1 = s.as_string(self.conn)
+ self.assert_(isinstance(s1, str))
+ self.assertEqual(s1, 'select "field" from "table"')
+
+ def test_pos_spec(self):
+ s = sql.SQL("select {0} from {1}").format(
+ sql.Identifier('field'), sql.Identifier('table'))
+ s1 = s.as_string(self.conn)
+ self.assert_(isinstance(s1, str))
+ self.assertEqual(s1, 'select "field" from "table"')
+
+ s = sql.SQL("select {1} from {0}").format(
+ sql.Identifier('table'), sql.Identifier('field'))
+ s1 = s.as_string(self.conn)
+ self.assert_(isinstance(s1, str))
+ self.assertEqual(s1, 'select "field" from "table"')
+
+ def test_dict(self):
+ s = sql.SQL("select {f} from {t}").format(
+ f=sql.Identifier('field'), t=sql.Identifier('table'))
+ s1 = s.as_string(self.conn)
+ self.assert_(isinstance(s1, str))
+ self.assertEqual(s1, 'select "field" from "table"')
+
+ def test_unicode(self):
+ s = sql.SQL(u"select {0} from {1}").format(
+ sql.Identifier(u'field'), sql.Identifier('table'))
+ s1 = s.as_string(self.conn)
+ self.assert_(isinstance(s1, unicode))
+ self.assertEqual(s1, u'select "field" from "table"')
+
+ def test_compose_literal(self):
+ s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
+ s1 = s.as_string(self.conn)
+ self.assertEqual(s1, "select '2016-12-31'::date;")
+
+ def test_compose_empty(self):
+ s = sql.SQL("select foo;").format()
+ s1 = s.as_string(self.conn)
+ self.assertEqual(s1, "select foo;")
+
+ def test_percent_escape(self):
+ s = sql.SQL("42 % {0}").format(sql.Literal(7))
+ s1 = s.as_string(self.conn)
+ self.assertEqual(s1, "42 % 7")
+
+ def test_braces_escape(self):
+ s = sql.SQL("{{{0}}}").format(sql.Literal(7))
+ self.assertEqual(s.as_string(self.conn), "{7}")
+ s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
+ self.assertEqual(s.as_string(self.conn), "{1,7}")
+
+ def test_compose_badnargs(self):
+ self.assertRaises(IndexError, sql.SQL("select {0};").format)
+
+ @skip_before_python(2, 7)
+ def test_compose_badnargs_auto(self):
+ self.assertRaises(IndexError, sql.SQL("select {};").format)
+ self.assertRaises(ValueError, sql.SQL("select {} {1};").format, 10, 20)
+ self.assertRaises(ValueError, sql.SQL("select {0} {};").format, 10, 20)
+
+ def test_compose_bad_args_type(self):
+ self.assertRaises(IndexError, sql.SQL("select {0};").format, a=10)
+ self.assertRaises(KeyError, sql.SQL("select {x};").format, 10)
+
+ def test_must_be_composable(self):
+ self.assertRaises(TypeError, sql.SQL("select {0};").format, 'foo')
+ self.assertRaises(TypeError, sql.SQL("select {0};").format, 10)
+
+ def test_no_modifiers(self):
+ self.assertRaises(ValueError, sql.SQL("select {a!r};").format, a=10)
+ self.assertRaises(ValueError, sql.SQL("select {a:<};").format, a=10)
+
+ def test_must_be_adaptable(self):
+ class Foo(object):
+ pass
+
+ self.assertRaises(psycopg2.ProgrammingError,
+ sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn)
+
+ def test_execute(self):
+ cur = self.conn.cursor()
+ cur.execute("""
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """)
+ cur.execute(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier('test_compose'),
+ sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
+ (sql.Placeholder() * 3).join(', ')),
+ (10, 'a', 'b', 'c'))
+
+ cur.execute("select * from test_compose")
+ self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
+
+ def test_executemany(self):
+ cur = self.conn.cursor()
+ cur.execute("""
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """)
+ cur.executemany(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier('test_compose'),
+ sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
+ (sql.Placeholder() * 3).join(', ')),
+ [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
+
+ cur.execute("select * from test_compose")
+ self.assertEqual(cur.fetchall(),
+ [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
+
+
+class IdentifierTests(ConnectingTestCase):
+ def test_class(self):
+ self.assert_(issubclass(sql.Identifier, sql.Composable))
+
+ def test_init(self):
+ self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
+ self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier))
+ self.assertRaises(TypeError, sql.Identifier, 10)
+ self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
+
+ def test_string(self):
+ self.assertEqual(sql.Identifier('foo').string, 'foo')
+
+ def test_repr(self):
+ obj = sql.Identifier("fo'o")
+ self.assertEqual(repr(obj), 'Identifier("fo\'o")')
+ self.assertEqual(repr(obj), str(obj))
+
+ def test_eq(self):
+ self.assert_(sql.Identifier('foo') == sql.Identifier('foo'))
+ self.assert_(sql.Identifier('foo') != sql.Identifier('bar'))
+ self.assert_(sql.Identifier('foo') != 'foo')
+ self.assert_(sql.Identifier('foo') != sql.SQL('foo'))
+
+ def test_as_str(self):
+ self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"')
+ self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"')
+
+ def test_join(self):
+ self.assert_(not hasattr(sql.Identifier('foo'), 'join'))
+
+
+class LiteralTests(ConnectingTestCase):
+ def test_class(self):
+ self.assert_(issubclass(sql.Literal, sql.Composable))
+
+ def test_init(self):
+ self.assert_(isinstance(sql.Literal('foo'), sql.Literal))
+ self.assert_(isinstance(sql.Literal(u'foo'), sql.Literal))
+ self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal))
+ self.assert_(isinstance(sql.Literal(42), sql.Literal))
+ self.assert_(isinstance(
+ sql.Literal(dt.date(2016, 12, 31)), sql.Literal))
+
+ def test_wrapped(self):
+ self.assertEqual(sql.Literal('foo').wrapped, 'foo')
+
+ def test_repr(self):
+ self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')")
+ self.assertEqual(str(sql.Literal("foo")), "Literal('foo')")
+ self.assertEqual(
+ sql.Literal("foo").as_string(self.conn).replace("E'", "'"),
+ "'foo'")
+ self.assertEqual(sql.Literal(42).as_string(self.conn), "42")
+ self.assertEqual(
+ sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
+ "'2017-01-01'::date")
+
+ def test_eq(self):
+ self.assert_(sql.Literal('foo') == sql.Literal('foo'))
+ self.assert_(sql.Literal('foo') != sql.Literal('bar'))
+ self.assert_(sql.Literal('foo') != 'foo')
+ self.assert_(sql.Literal('foo') != sql.SQL('foo'))
+
+ def test_must_be_adaptable(self):
+ class Foo(object):
+ pass
+
+ self.assertRaises(psycopg2.ProgrammingError,
+ sql.Literal(Foo()).as_string, self.conn)
+
+
+class SQLTests(ConnectingTestCase):
+ def test_class(self):
+ self.assert_(issubclass(sql.SQL, sql.Composable))
+
+ def test_init(self):
+ self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
+ self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL))
+ self.assertRaises(TypeError, sql.SQL, 10)
+ self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
+
+ def test_string(self):
+ self.assertEqual(sql.SQL('foo').string, 'foo')
+
+ def test_repr(self):
+ self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')")
+ self.assertEqual(str(sql.SQL("foo")), "SQL('foo')")
+ self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo")
+
+ def test_eq(self):
+ self.assert_(sql.SQL('foo') == sql.SQL('foo'))
+ self.assert_(sql.SQL('foo') != sql.SQL('bar'))
+ self.assert_(sql.SQL('foo') != 'foo')
+ self.assert_(sql.SQL('foo') != sql.Literal('foo'))
+
+ def test_sum(self):
+ obj = sql.SQL("foo") + sql.SQL("bar")
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foobar")
+
+ def test_sum_inplace(self):
+ obj = sql.SQL("foo")
+ obj += sql.SQL("bar")
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foobar")
+
+ def test_multiply(self):
+ obj = sql.SQL("foo") * 3
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foofoofoo")
+
+ def test_join(self):
+ obj = sql.SQL(", ").join(
+ [sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)])
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
+
+ obj = sql.SQL(", ").join(
+ sql.Composed([sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)]))
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
+
+ obj = sql.SQL(", ").join([])
+ self.assertEqual(obj, sql.Composed([]))
+
+
+class ComposedTest(ConnectingTestCase):
+ def test_class(self):
+ self.assert_(issubclass(sql.Composed, sql.Composable))
+
+ def test_repr(self):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ self.assertEqual(repr(obj),
+ """Composed([Literal('foo'), Identifier("b'ar")])""")
+ self.assertEqual(str(obj), repr(obj))
+
+ def test_seq(self):
+ l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')]
+ self.assertEqual(sql.Composed(l).seq, l)
+
+ def test_eq(self):
+ l = [sql.Literal("foo"), sql.Identifier("b'ar")]
+ l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
+ self.assert_(sql.Composed(l) == sql.Composed(list(l)))
+ self.assert_(sql.Composed(l) != l)
+ self.assert_(sql.Composed(l) != sql.Composed(l2))
+
+ def test_join(self):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ obj = obj.join(", ")
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "'foo', \"b'ar\"")
+
+ def test_sum(self):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj = obj + sql.Literal("bar")
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+
+ def test_sum_inplace(self):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Literal("bar")
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Composed([sql.Literal("bar")])
+ self.assert_(isinstance(obj, sql.Composed))
+ self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+
+ def test_iter(self):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')])
+ it = iter(obj)
+ i = it.next()
+ self.assertEqual(i, sql.SQL('foo'))
+ i = it.next()
+ self.assertEqual(i, sql.SQL('bar'))
+ self.assertRaises(StopIteration, it.next)
+
+
+class PlaceholderTest(ConnectingTestCase):
+ def test_class(self):
+ self.assert_(issubclass(sql.Placeholder, sql.Composable))
+
+ def test_name(self):
+ self.assertEqual(sql.Placeholder().name, None)
+ self.assertEqual(sql.Placeholder('foo').name, 'foo')
+
+ def test_repr(self):
+ self.assert_(str(sql.Placeholder()), 'Placeholder()')
+ self.assert_(repr(sql.Placeholder()), 'Placeholder()')
+ self.assert_(sql.Placeholder().as_string(self.conn), '%s')
+
+ def test_repr_name(self):
+ self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')")
+ self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')")
+ self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s')
+
+ def test_bad_name(self):
+ self.assertRaises(ValueError, sql.Placeholder, ')')
+
+ def test_eq(self):
+ self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo'))
+ self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar'))
+ self.assert_(sql.Placeholder('foo') != 'foo')
+ self.assert_(sql.Placeholder() == sql.Placeholder())
+ self.assert_(sql.Placeholder('foo') != sql.Placeholder())
+ self.assert_(sql.Placeholder('foo') != sql.Literal('foo'))
+
+
+class ValuesTest(ConnectingTestCase):
+ def test_null(self):
+ self.assert_(isinstance(sql.NULL, sql.SQL))
+ self.assertEqual(sql.NULL.as_string(self.conn), "NULL")
+
+ def test_default(self):
+ self.assert_(isinstance(sql.DEFAULT, sql.SQL))
+ self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT")
+
+
+def test_suite():
+ return unittest.TestLoader().loadTestsFromName(__name__)
+
+if __name__ == "__main__":
+ unittest.main()