summaryrefslogtreecommitdiff
path: root/tests/test_sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_sql.py')
-rwxr-xr-xtests/test_sql.py89
1 files changed, 74 insertions, 15 deletions
diff --git a/tests/test_sql.py b/tests/test_sql.py
index c2268fd..21c761e 100755
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -120,7 +120,7 @@ class SqlFormatTests(ConnectingTestCase):
sql.SQL("insert into {} (id, {}) values (%s, {})").format(
sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
- (sql.PH() * 3).join(', ')),
+ (sql.Placeholder() * 3).join(', ')),
(10, 'a', 'b', 'c'))
cur.execute("select * from test_compose")
@@ -137,7 +137,7 @@ class SqlFormatTests(ConnectingTestCase):
sql.SQL("insert into {} (id, {}) values (%s, {})").format(
sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
- (sql.PH() * 3).join(', ')),
+ (sql.Placeholder() * 3).join(', ')),
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
cur.execute("select * from test_compose")
@@ -155,11 +155,20 @@ class IdentifierTests(ConnectingTestCase):
self.assertRaises(TypeError, sql.Identifier, 10)
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
+ def test_string(self):
+ self.assertEqual(sql.Identifier('foo').string, 'foo')
+
def test_repr(self):
obj = sql.Identifier("fo'o")
- self.assertEqual(repr(obj), 'sql.Identifier("fo\'o")')
+ self.assertEqual(repr(obj), 'Identifier("fo\'o")')
self.assertEqual(repr(obj), str(obj))
+ def test_eq(self):
+ self.assert_(sql.Identifier('foo') == sql.Identifier('foo'))
+ self.assert_(sql.Identifier('foo') != sql.Identifier('bar'))
+ self.assert_(sql.Identifier('foo') != 'foo')
+ self.assert_(sql.Identifier('foo') != sql.SQL('foo'))
+
def test_as_str(self):
self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"')
self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"')
@@ -180,9 +189,12 @@ class LiteralTests(ConnectingTestCase):
self.assert_(isinstance(
sql.Literal(dt.date(2016, 12, 31)), sql.Literal))
+ def test_wrapped(self):
+ self.assertEqual(sql.Literal('foo').wrapped, 'foo')
+
def test_repr(self):
- self.assertEqual(repr(sql.Literal("foo")), "sql.Literal('foo')")
- self.assertEqual(str(sql.Literal("foo")), "sql.Literal('foo')")
+ self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')")
+ self.assertEqual(str(sql.Literal("foo")), "Literal('foo')")
self.assertEqual(
sql.Literal("foo").as_string(self.conn).replace("E'", "'"),
"'foo'")
@@ -191,6 +203,12 @@ class LiteralTests(ConnectingTestCase):
sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn),
"'2017-01-01'::date")
+ def test_eq(self):
+ self.assert_(sql.Literal('foo') == sql.Literal('foo'))
+ self.assert_(sql.Literal('foo') != sql.Literal('bar'))
+ self.assert_(sql.Literal('foo') != 'foo')
+ self.assert_(sql.Literal('foo') != sql.SQL('foo'))
+
def test_must_be_adaptable(self):
class Foo(object):
pass
@@ -209,11 +227,20 @@ class SQLTests(ConnectingTestCase):
self.assertRaises(TypeError, sql.SQL, 10)
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
- def test_str(self):
- self.assertEqual(repr(sql.SQL("foo")), "sql.SQL('foo')")
- self.assertEqual(str(sql.SQL("foo")), "sql.SQL('foo')")
+ def test_string(self):
+ self.assertEqual(sql.SQL('foo').string, 'foo')
+
+ def test_repr(self):
+ self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')")
+ self.assertEqual(str(sql.SQL("foo")), "SQL('foo')")
self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo")
+ def test_eq(self):
+ self.assert_(sql.SQL('foo') == sql.SQL('foo'))
+ self.assert_(sql.SQL('foo') != sql.SQL('bar'))
+ self.assert_(sql.SQL('foo') != 'foo')
+ self.assert_(sql.SQL('foo') != sql.Literal('foo'))
+
def test_sum(self):
obj = sql.SQL("foo") + sql.SQL("bar")
self.assert_(isinstance(obj, sql.Composed))
@@ -241,6 +268,9 @@ class SQLTests(ConnectingTestCase):
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42')
+ obj = sql.SQL(", ").join([])
+ self.assertEqual(obj, sql.Composed([]))
+
class ComposedTest(ConnectingTestCase):
def test_class(self):
@@ -249,9 +279,20 @@ class ComposedTest(ConnectingTestCase):
def test_repr(self):
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
self.assertEqual(repr(obj),
- """sql.Composed([sql.Literal('foo'), sql.Identifier("b'ar")])""")
+ """Composed([Literal('foo'), Identifier("b'ar")])""")
self.assertEqual(str(obj), repr(obj))
+ def test_seq(self):
+ l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')]
+ self.assertEqual(sql.Composed(l).seq, l)
+
+ def test_eq(self):
+ l = [sql.Literal("foo"), sql.Identifier("b'ar")]
+ l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
+ self.assert_(sql.Composed(l) == sql.Composed(list(l)))
+ self.assert_(sql.Composed(l) != l)
+ self.assert_(sql.Composed(l) != sql.Composed(l2))
+
def test_join(self):
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
obj = obj.join(", ")
@@ -275,27 +316,45 @@ class ComposedTest(ConnectingTestCase):
self.assert_(isinstance(obj, sql.Composed))
self.assertEqual(obj.as_string(self.conn), "foo 'bar'")
+ def test_iter(self):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')])
+ it = iter(obj)
+ i = it.next()
+ self.assertEqual(i, sql.SQL('foo'))
+ i = it.next()
+ self.assertEqual(i, sql.SQL('bar'))
+ self.assertRaises(StopIteration, it.next)
+
class PlaceholderTest(ConnectingTestCase):
def test_class(self):
self.assert_(issubclass(sql.Placeholder, sql.Composable))
- def test_alias(self):
- self.assert_(sql.Placeholder is sql.PH)
+ def test_name(self):
+ self.assertEqual(sql.Placeholder().name, None)
+ self.assertEqual(sql.Placeholder('foo').name, 'foo')
def test_repr(self):
- self.assert_(str(sql.Placeholder()), 'sql.Placeholder()')
- self.assert_(repr(sql.Placeholder()), 'sql.Placeholder()')
+ self.assert_(str(sql.Placeholder()), 'Placeholder()')
+ self.assert_(repr(sql.Placeholder()), 'Placeholder()')
self.assert_(sql.Placeholder().as_string(self.conn), '%s')
def test_repr_name(self):
- self.assert_(str(sql.Placeholder('foo')), "sql.Placeholder('foo')")
- self.assert_(repr(sql.Placeholder('foo')), "sql.Placeholder('foo')")
+ self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')")
+ self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')")
self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s')
def test_bad_name(self):
self.assertRaises(ValueError, sql.Placeholder, ')')
+ def test_eq(self):
+ self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo'))
+ self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar'))
+ self.assert_(sql.Placeholder('foo') != 'foo')
+ self.assert_(sql.Placeholder() == sql.Placeholder())
+ self.assert_(sql.Placeholder('foo') != sql.Placeholder())
+ self.assert_(sql.Placeholder('foo') != sql.Literal('foo'))
+
class ValuesTest(ConnectingTestCase):
def test_null(self):