diff options
Diffstat (limited to 'tests/test_sql.py')
| -rwxr-xr-x | tests/test_sql.py | 89 |
1 files changed, 74 insertions, 15 deletions
diff --git a/tests/test_sql.py b/tests/test_sql.py index c2268fd..21c761e 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -120,7 +120,7 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', ')), + (sql.Placeholder() * 3).join(', ')), (10, 'a', 'b', 'c')) cur.execute("select * from test_compose") @@ -137,7 +137,7 @@ class SqlFormatTests(ConnectingTestCase): sql.SQL("insert into {} (id, {}) values (%s, {})").format( sql.Identifier('test_compose'), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), - (sql.PH() * 3).join(', ')), + (sql.Placeholder() * 3).join(', ')), [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) cur.execute("select * from test_compose") @@ -155,11 +155,20 @@ class IdentifierTests(ConnectingTestCase): self.assertRaises(TypeError, sql.Identifier, 10) self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) + def test_string(self): + self.assertEqual(sql.Identifier('foo').string, 'foo') + def test_repr(self): obj = sql.Identifier("fo'o") - self.assertEqual(repr(obj), 'sql.Identifier("fo\'o")') + self.assertEqual(repr(obj), 'Identifier("fo\'o")') self.assertEqual(repr(obj), str(obj)) + def test_eq(self): + self.assert_(sql.Identifier('foo') == sql.Identifier('foo')) + self.assert_(sql.Identifier('foo') != sql.Identifier('bar')) + self.assert_(sql.Identifier('foo') != 'foo') + self.assert_(sql.Identifier('foo') != sql.SQL('foo')) + def test_as_str(self): self.assertEqual(sql.Identifier('foo').as_string(self.conn), '"foo"') self.assertEqual(sql.Identifier("fo'o").as_string(self.conn), '"fo\'o"') @@ -180,9 +189,12 @@ class LiteralTests(ConnectingTestCase): self.assert_(isinstance( sql.Literal(dt.date(2016, 12, 31)), sql.Literal)) + def test_wrapped(self): + self.assertEqual(sql.Literal('foo').wrapped, 'foo') + def test_repr(self): - self.assertEqual(repr(sql.Literal("foo")), "sql.Literal('foo')") - self.assertEqual(str(sql.Literal("foo")), "sql.Literal('foo')") + self.assertEqual(repr(sql.Literal("foo")), "Literal('foo')") + self.assertEqual(str(sql.Literal("foo")), "Literal('foo')") self.assertEqual( sql.Literal("foo").as_string(self.conn).replace("E'", "'"), "'foo'") @@ -191,6 +203,12 @@ class LiteralTests(ConnectingTestCase): sql.Literal(dt.date(2017, 1, 1)).as_string(self.conn), "'2017-01-01'::date") + def test_eq(self): + self.assert_(sql.Literal('foo') == sql.Literal('foo')) + self.assert_(sql.Literal('foo') != sql.Literal('bar')) + self.assert_(sql.Literal('foo') != 'foo') + self.assert_(sql.Literal('foo') != sql.SQL('foo')) + def test_must_be_adaptable(self): class Foo(object): pass @@ -209,11 +227,20 @@ class SQLTests(ConnectingTestCase): self.assertRaises(TypeError, sql.SQL, 10) self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31)) - def test_str(self): - self.assertEqual(repr(sql.SQL("foo")), "sql.SQL('foo')") - self.assertEqual(str(sql.SQL("foo")), "sql.SQL('foo')") + def test_string(self): + self.assertEqual(sql.SQL('foo').string, 'foo') + + def test_repr(self): + self.assertEqual(repr(sql.SQL("foo")), "SQL('foo')") + self.assertEqual(str(sql.SQL("foo")), "SQL('foo')") self.assertEqual(sql.SQL("foo").as_string(self.conn), "foo") + def test_eq(self): + self.assert_(sql.SQL('foo') == sql.SQL('foo')) + self.assert_(sql.SQL('foo') != sql.SQL('bar')) + self.assert_(sql.SQL('foo') != 'foo') + self.assert_(sql.SQL('foo') != sql.Literal('foo')) + def test_sum(self): obj = sql.SQL("foo") + sql.SQL("bar") self.assert_(isinstance(obj, sql.Composed)) @@ -241,6 +268,9 @@ class SQLTests(ConnectingTestCase): self.assert_(isinstance(obj, sql.Composed)) self.assertEqual(obj.as_string(self.conn), '"foo", bar, 42') + obj = sql.SQL(", ").join([]) + self.assertEqual(obj, sql.Composed([])) + class ComposedTest(ConnectingTestCase): def test_class(self): @@ -249,9 +279,20 @@ class ComposedTest(ConnectingTestCase): def test_repr(self): obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) self.assertEqual(repr(obj), - """sql.Composed([sql.Literal('foo'), sql.Identifier("b'ar")])""") + """Composed([Literal('foo'), Identifier("b'ar")])""") self.assertEqual(str(obj), repr(obj)) + def test_seq(self): + l = [sql.SQL('foo'), sql.Literal('bar'), sql.Identifier('baz')] + self.assertEqual(sql.Composed(l).seq, l) + + def test_eq(self): + l = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + self.assert_(sql.Composed(l) == sql.Composed(list(l))) + self.assert_(sql.Composed(l) != l) + self.assert_(sql.Composed(l) != sql.Composed(l2)) + def test_join(self): obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) obj = obj.join(", ") @@ -275,27 +316,45 @@ class ComposedTest(ConnectingTestCase): self.assert_(isinstance(obj, sql.Composed)) self.assertEqual(obj.as_string(self.conn), "foo 'bar'") + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL('bar')]) + it = iter(obj) + i = it.next() + self.assertEqual(i, sql.SQL('foo')) + i = it.next() + self.assertEqual(i, sql.SQL('bar')) + self.assertRaises(StopIteration, it.next) + class PlaceholderTest(ConnectingTestCase): def test_class(self): self.assert_(issubclass(sql.Placeholder, sql.Composable)) - def test_alias(self): - self.assert_(sql.Placeholder is sql.PH) + def test_name(self): + self.assertEqual(sql.Placeholder().name, None) + self.assertEqual(sql.Placeholder('foo').name, 'foo') def test_repr(self): - self.assert_(str(sql.Placeholder()), 'sql.Placeholder()') - self.assert_(repr(sql.Placeholder()), 'sql.Placeholder()') + self.assert_(str(sql.Placeholder()), 'Placeholder()') + self.assert_(repr(sql.Placeholder()), 'Placeholder()') self.assert_(sql.Placeholder().as_string(self.conn), '%s') def test_repr_name(self): - self.assert_(str(sql.Placeholder('foo')), "sql.Placeholder('foo')") - self.assert_(repr(sql.Placeholder('foo')), "sql.Placeholder('foo')") + self.assert_(str(sql.Placeholder('foo')), "Placeholder('foo')") + self.assert_(repr(sql.Placeholder('foo')), "Placeholder('foo')") self.assert_(sql.Placeholder('foo').as_string(self.conn), '%(foo)s') def test_bad_name(self): self.assertRaises(ValueError, sql.Placeholder, ')') + def test_eq(self): + self.assert_(sql.Placeholder('foo') == sql.Placeholder('foo')) + self.assert_(sql.Placeholder('foo') != sql.Placeholder('bar')) + self.assert_(sql.Placeholder('foo') != 'foo') + self.assert_(sql.Placeholder() == sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Placeholder()) + self.assert_(sql.Placeholder('foo') != sql.Literal('foo')) + class ValuesTest(ConnectingTestCase): def test_null(self): |
