diff options
-rw-r--r-- | lib/sql.py | 174 | ||||
-rwxr-xr-x | tests/test_sql.py | 223 |
2 files changed, 396 insertions, 1 deletions
@@ -23,3 +23,177 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. +from psycopg2 import extensions as ext + + +class Composible(object): + """Base class for objects that can be used to compose an SQL string.""" + def as_string(self, conn_or_curs): + raise NotImplementedError + + def __add__(self, other): + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composible): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + +class Composed(Composible): + def __init__(self, seq): + self._seq = [] + for i in seq: + if not isinstance(i, Composible): + raise TypeError( + "Composed elements must be Composible, got %r instead" % i) + self._seq.append(i) + + def __repr__(self): + return "sql.Composed(%r)" % (self.seq,) + + def as_string(self, conn_or_curs): + rv = [] + for i in self._seq: + rv.append(i.as_string(conn_or_curs)) + return ''.join(rv) + + def __add__(self, other): + if isinstance(other, Composed): + return Composed(self._seq + other._seq) + if isinstance(other, Composible): + return Composed(self._seq + [other]) + else: + return NotImplemented + + def __mul__(self, n): + return Composed(self._seq * n) + + def join(self, joiner): + if isinstance(joiner, basestring): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be a string or an SQL") + + if len(self._seq) <= 1: + return self + + it = iter(self._seq) + rv = [it.next()] + for i in it: + rv.append(joiner) + rv.append(i) + + return Composed(rv) + + +class SQL(Composible): + def __init__(self, wrapped): + if not isinstance(wrapped, basestring): + raise TypeError("SQL values must be strings") + self._wrapped = wrapped + + def __repr__(self): + return "sql.SQL(%r)" % (self._wrapped,) + + def as_string(self, conn_or_curs): + return self._wrapped + + def __mul__(self, n): + return Composed([self] * n) + + def join(self, seq): + rv = [] + it = iter(seq) + try: + rv.append(it.next()) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composible): + def __init__(self, wrapped): + if not isinstance(wrapped, basestring): + raise TypeError("SQL identifiers must be strings") + + self._wrapped = wrapped + + @property + def wrapped(self): + return self._wrapped + + def __repr__(self): + return "sql.Identifier(%r)" % (self._wrapped,) + + def as_string(self, conn_or_curs): + return ext.quote_ident(self._wrapped, conn_or_curs) + + +class Literal(Composible): + def __init__(self, wrapped): + self._wrapped = wrapped + + def __repr__(self): + return "sql.Literal(%r)" % (self._wrapped,) + + def as_string(self, conn_or_curs): + a = ext.adapt(self._wrapped) + if hasattr(a, 'prepare'): + # is it a connection or cursor? + if isinstance(conn_or_curs, ext.connection): + conn = conn_or_curs + elif isinstance(conn_or_curs, ext.cursor): + conn = conn_or_curs.connection + else: + raise TypeError("conn_or_curs must be a connection or a cursor") + + a.prepare(conn) + + return a.getquoted() + + def __mul__(self, n): + return Composed([self] * n) + + +class Placeholder(Composible): + def __init__(self, name=None): + if isinstance(name, basestring): + if ')' in name: + raise ValueError("invalid name: %r" % name) + + elif name is not None: + raise TypeError("expected string or None as name, got %r" % name) + + self._name = name + + def __repr__(self): + return "sql.Placeholder(%r)" % ( + self._name if self._name is not None else '',) + + def __mul__(self, n): + return Composed([self] * n) + + def as_string(self, conn_or_curs): + if self._name is not None: + return "%%(%s)s" % self._name + else: + return "%s" + + +def compose(sql, args=()): + raise NotImplementedError + + +# Alias +PH = Placeholder + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/tests/test_sql.py b/tests/test_sql.py index e2f6670..510b545 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -22,7 +22,228 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. -from testutils import unittest +import datetime as dt +from testutils import unittest, ConnectingTestCase + +from psycopg2 import sql + + +class ComposeTests(ConnectingTestCase): + def test_pos(self): + s = sql.compose("select %s from %s", + (sql.Identifier('field'), sql.Identifier('table'))) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_dict(self): + s = sql.compose("select %(f)s from %(t)s", + {'f': sql.Identifier('field'), 't': sql.Identifier('table')}) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, str)) + self.assertEqual(s1, 'select "field" from "table"') + + def test_unicode(self): + s = sql.compose(u"select %s from %s", + (sql.Identifier(u'field'), sql.Identifier('table'))) + s1 = s.as_string(self.conn) + self.assert_(isinstance(s1, unicode)) + self.assertEqual(s1, u'select "field" from "table"') + + def test_compose_literal(self): + s = sql.compose("select %s;", [sql.Literal(dt.date(2016, 12, 31))]) + s1 = s.as_string(self.conn) + self.assertEqual(s1, "select '2016-12-31'::date;") + + def test_must_be_adaptable(self): + class Foo(object): + pass + + self.assertRaises(TypeError, + sql.compose, "select %s;", [Foo()]) + + def test_execute(self): + cur = self.conn.cursor() + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.execute( + sql.compose("insert into %s (id, %s) values (%%s, %s)", [ + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.PH() * 3).join(', '), + ]), + (10, 'a', 'b', 'c')) + + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) + + def test_executemany(self): + cur = self.conn.cursor() + cur.execute(""" + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """) + cur.executemany( + sql.compose("insert into %s (id, %s) values (%%s, %s)", [ + sql.Identifier('test_compose'), + sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), + (sql.PH() * 3).join(', '), + ]), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + + cur.execute("select * from test_compose") + self.assertEqual(cur.fetchall(), + [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) + + +class IdentifierTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Identifier, sql.Composible)) + + def test_init(self): + self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier)) + self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier)) + self.assert_(isinstance(sql.Identifier(b'foo'), sql.Identifier)) + self.assertRaises(TypeError, sql.Identifier, 10) + self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) + + def test_repr(self): + obj = sql.Identifier("fo'o") + self.assertEqual(repr(obj), 'sql.Identifier("fo\'o")') + self.assertEqual(repr(obj), str(obj)) + + def test_as_str(self): + self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"') + self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"') + + def test_join(self): + self.assert_(not hasattr(sql.Identifier('foo'), 'join')) + + +class LiteralTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Literal, sql.Composible)) + + def test_init(self): + self.assert_(isinstance(sql.Literal('foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(u'foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(b'foo'), sql.Literal)) + self.assert_(isinstance(sql.Literal(42), sql.Literal)) + self.assert_(isinstance( + sql.Literal(dt.date(2016, 12, 31)), sql.Literal)) + + def test_repr(self): + self.assertEqual(repr(sql.Literal("foo")), "sql.Literal('foo')") + self.assertEqual(str(sql.Literal("foo")), "sql.Literal('foo')") + self.assertEqual( + sql.Literal("foo").as_string(self.conn).replace("E'", "'"), + "'foo'") + self.assertEqual(sql.Literal(42).as_string(self.conn), "42") + self.assertEqual( + sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), + "'2017-01-01'::date") + + +class SQLTests(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.SQL, sql.Composible)) + + def test_init(self): + self.assert_(isinstance(sql.SQL('foo'), sql.SQL)) + self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL)) + self.assert_(isinstance(sql.SQL(b'foo'), sql.SQL)) + self.assertRaises(TypeError, sql.SQL, 10) + self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31)) + + def test_str(self): + self.assertEqual(repr(sql.SQL("foo")), "sql.SQL('foo')") + self.assertEqual(str(sql.SQL("foo")), "sql.SQL('foo')") + self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo") + + def test_sum(self): + obj = sql.SQL("foo") + sql.SQL("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foobar") + + def test_sum_inplace(self): + obj = sql.SQL("foo") + obj += sql.SQL("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foobar") + + def test_multiply(self): + obj = sql.SQL("foo") * 3 + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foofoofoo") + + def test_join(self): + obj = sql.SQL(", ").join( + [sql.Identifier('foo'), sql.SQL('bar'), sql.Literal(42)]) + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42') + + +class ComposedTest(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Composed, sql.Composible)) + + def test_join(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "'foo', \"b'ar\"") + + def test_sum(self): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + def test_sum_inplace(self): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + self.assert_(isinstance(obj, sql.Composed)) + self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + + +class PlaceholderTest(ConnectingTestCase): + def test_class(self): + self.assert_(issubclass(sql.Placeholder, sql.Composible)) + + def test_alias(self): + self.assert_(sql.Placeholder is sql.PH) + + def test_repr(self): + self.assert_(str(sql.Placeholder()), 'sql.Placeholder()') + self.assert_(repr(sql.Placeholder()), 'sql.Placeholder()') + self.assert_(sql.Placeholder().as_string(self.conn), '%s') + + def test_repr_name(self): + self.assert_(str(sql.Placeholder('foo')), "sql.Placeholder('foo')") + self.assert_(repr(sql.Placeholder('foo')), "sql.Placeholder('foo')") + self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s') + + def test_bad_name(self): + self.assertRaises(ValueError, sql.Placeholder, ')') + + +class ValuesTest(ConnectingTestCase): + def test_null(self): + self.assert_(isinstance(sql.NULL, sql.SQL)) + self.assertEqual(sql.NULL.as_string(self.conn), "NULL") + + def test_default(self): + self.assert_(isinstance(sql.DEFAULT, sql.SQL)) + self.assertEqual(sql.DEFAULT.as_string(self.conn), "DEFAULT") def test_suite(): |