summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2009-04-03 21:26:42 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2009-04-03 21:26:42 +0200
commit361122eb22d5681c58dac731009e4814b3dd5fa5 (patch)
treeb096496bc9c6b8febe092d0aefd56de1a4f8f4a0 /tests
downloadsqlparse-361122eb22d5681c58dac731009e4814b3dd5fa5.tar.gz
Initial import.
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/files/begintag.sql4
-rw-r--r--tests/files/dashcomment.sql5
-rw-r--r--tests/files/function.sql13
-rw-r--r--tests/files/function_psql.sql72
-rw-r--r--tests/files/function_psql2.sql7
-rw-r--r--tests/files/function_psql3.sql8
-rwxr-xr-xtests/run_tests.py31
-rw-r--r--tests/test_format.py146
-rw-r--r--tests/test_grouping.py86
-rw-r--r--tests/test_parse.py39
-rw-r--r--tests/test_split.py88
-rw-r--r--tests/test_tokenize.py21
-rw-r--r--tests/utils.py38
14 files changed, 558 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/files/begintag.sql b/tests/files/begintag.sql
new file mode 100644
index 0000000..699b365
--- /dev/null
+++ b/tests/files/begintag.sql
@@ -0,0 +1,4 @@
+begin;
+update foo
+ set bar = 1;
+commit; \ No newline at end of file
diff --git a/tests/files/dashcomment.sql b/tests/files/dashcomment.sql
new file mode 100644
index 0000000..0d5ac62
--- /dev/null
+++ b/tests/files/dashcomment.sql
@@ -0,0 +1,5 @@
+select * from user;
+--select * from host;
+select * from user;
+select * -- foo;
+from foo; \ No newline at end of file
diff --git a/tests/files/function.sql b/tests/files/function.sql
new file mode 100644
index 0000000..d19227f
--- /dev/null
+++ b/tests/files/function.sql
@@ -0,0 +1,13 @@
+CREATE OR REPLACE FUNCTION foo(
+ p_in1 VARCHAR
+ , p_in2 INTEGER
+) RETURNS INTEGER AS
+
+ DECLARE
+ v_foo INTEGER;
+ BEGIN
+ SELECT *
+ FROM foo
+ INTO v_foo;
+ RETURN v_foo.id;
+ END; \ No newline at end of file
diff --git a/tests/files/function_psql.sql b/tests/files/function_psql.sql
new file mode 100644
index 0000000..e485f7a
--- /dev/null
+++ b/tests/files/function_psql.sql
@@ -0,0 +1,72 @@
+CREATE OR REPLACE FUNCTION public.delete_data (
+ p_tabelle VARCHAR
+ , p_key VARCHAR
+ , p_value INTEGER
+) RETURNS INTEGER AS
+$$
+DECLARE
+ p_retval INTEGER;
+ v_constraint RECORD;
+ v_count INTEGER;
+ v_data RECORD;
+ v_fieldname VARCHAR;
+ v_sql VARCHAR;
+ v_key VARCHAR;
+ v_value INTEGER;
+BEGIN
+ v_sql := 'SELECT COUNT(*) FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value;
+ --RAISE NOTICE '%', v_sql;
+ EXECUTE v_sql INTO v_count;
+ IF v_count::integer != 0 THEN
+ SELECT att.attname
+ INTO v_key
+ FROM pg_attribute att
+ LEFT JOIN pg_constraint con ON con.conrelid = att.attrelid
+ AND con.conkey[1] = att.attnum
+ AND con.contype = 'p', pg_type typ, pg_class rel, pg_namespace ns
+ WHERE att.attrelid = rel.oid
+ AND att.attnum > 0
+ AND typ.oid = att.atttypid
+ AND att.attisdropped = false
+ AND rel.relname = p_tabelle
+ AND con.conkey[1] = 1
+ AND ns.oid = rel.relnamespace
+ AND ns.nspname = 'public'
+ ORDER BY att.attnum;
+ v_sql := 'SELECT ' || v_key || ' AS id FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value;
+ FOR v_data IN EXECUTE v_sql
+ LOOP
+ --RAISE NOTICE ' -> % %', p_tabelle, v_data.id;
+ FOR v_constraint IN SELECT t.constraint_name
+ , t.constraint_type
+ , t.table_name
+ , c.column_name
+ FROM public.v_table_constraints t
+ , public.v_constraint_columns c
+ WHERE t.constraint_name = c.constraint_name
+ AND t.constraint_type = 'FOREIGN KEY'
+ AND c.table_name = p_tabelle
+ AND t.table_schema = 'public'
+ AND c.table_schema = 'public'
+ LOOP
+ v_fieldname := substring(v_constraint.constraint_name from 1 for length(v_constraint.constraint_name) - length(v_constraint.column_name) - 1);
+ IF (v_constraint.table_name = p_tabelle) AND (p_value = v_data.id) THEN
+ --RAISE NOTICE 'Skip (Selbstverweis)';
+ CONTINUE;
+ ELSE
+ PERFORM delete_data(v_constraint.table_name::varchar, v_fieldname::varchar, v_data.id::integer);
+ END IF;
+ END LOOP;
+ END LOOP;
+ v_sql := 'DELETE FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value;
+ --RAISE NOTICE '%', v_sql;
+ EXECUTE v_sql;
+ p_retval := 1;
+ ELSE
+ --RAISE NOTICE ' -> Keine Sätze gefunden';
+ p_retval := 0;
+ END IF;
+ RETURN p_retval;
+END;
+$$
+LANGUAGE plpgsql; \ No newline at end of file
diff --git a/tests/files/function_psql2.sql b/tests/files/function_psql2.sql
new file mode 100644
index 0000000..b5d494c
--- /dev/null
+++ b/tests/files/function_psql2.sql
@@ -0,0 +1,7 @@
+CREATE OR REPLACE FUNCTION update_something() RETURNS void AS
+$body$
+BEGIN
+ raise notice 'foo';
+END;
+$body$
+LANGUAGE 'plpgsql' VOLATILE CALLED ON NULL INPUT SECURITY INVOKER; \ No newline at end of file
diff --git a/tests/files/function_psql3.sql b/tests/files/function_psql3.sql
new file mode 100644
index 0000000..b25d818
--- /dev/null
+++ b/tests/files/function_psql3.sql
@@ -0,0 +1,8 @@
+CREATE OR REPLACE FUNCTION foo() RETURNS integer AS
+$body$
+DECLARE
+BEGIN
+ select * from foo;
+END;
+$body$
+LANGUAGE 'plpgsql' VOLATILE CALLED ON NULL INPUT SECURITY INVOKER; \ No newline at end of file
diff --git a/tests/run_tests.py b/tests/run_tests.py
new file mode 100755
index 0000000..1c7960e
--- /dev/null
+++ b/tests/run_tests.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""Test runner for sqlparse."""
+
+import os
+import sys
+import unittest
+
+sys.path.insert(1, os.path.join(os.path.dirname(__file__), '../'))
+
+
+def main():
+ """Create a TestSuite and run it."""
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ fnames = [os.path.split(f)[-1] for f in sys.argv[1:]]
+ for fname in os.listdir(os.path.dirname(__file__)):
+ if (not fname.startswith('test_') or not fname.endswith('.py')
+ or (fnames and fname not in fnames)):
+ continue
+ modname = os.path.splitext(fname)[0]
+ mod = __import__(os.path.splitext(fname)[0])
+ suite.addTests(loader.loadTestsFromModule(mod))
+ unittest.TextTestRunner(verbosity=2).run(suite)
+
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/test_format.py b/tests/test_format.py
new file mode 100644
index 0000000..5748704
--- /dev/null
+++ b/tests/test_format.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+
+from tests.utils import TestCaseBase
+
+import sqlparse
+
+
+class TestFormat(TestCaseBase):
+
+ def test_keywordcase(self):
+ sql = 'select * from bar; -- select foo\n'
+ res = sqlparse.format(sql, keyword_case='upper')
+ self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
+ res = sqlparse.format(sql, keyword_case='capitalize')
+ self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
+ res = sqlparse.format(sql.upper(), keyword_case='lower')
+ self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
+ self.assertRaises(sqlparse.SQLParseError, sqlparse.format, sql,
+ keyword_case='foo')
+
+ def test_identifiercase(self):
+ sql = 'select * from bar; -- select foo\n'
+ res = sqlparse.format(sql, identifier_case='upper')
+ self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
+ res = sqlparse.format(sql, identifier_case='capitalize')
+ self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
+ res = sqlparse.format(sql.upper(), identifier_case='lower')
+ self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
+ self.assertRaises(sqlparse.SQLParseError, sqlparse.format, sql,
+ identifier_case='foo')
+ sql = 'select * from "foo"."bar"'
+ res = sqlparse.format(sql, identifier_case="upper")
+ self.ndiffAssertEqual(res, 'select * from "FOO"."BAR"')
+
+ def test_strip_comments_single(self):
+ sql = 'select *-- statement starts here\nfrom foo'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select * from foo')
+ sql = 'select * -- statement starts here\nfrom foo'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select * from foo')
+ sql = 'select-- foo\nfrom -- bar\nwhere'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select from where')
+
+ def test_strip_comments_multi(self):
+ sql = '/* sql starts here */\nselect'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select')
+ sql = '/* sql starts here */ select'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select')
+ sql = '/*\n * sql starts here\n */\nselect'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select')
+ sql = 'select (/* sql starts here */ select 2)'
+ res = sqlparse.format(sql, strip_comments=True)
+ self.ndiffAssertEqual(res, 'select (select 2)')
+
+ def test_strip_ws(self):
+ f = lambda sql: sqlparse.format(sql, strip_whitespace=True)
+ s = 'select\n* from foo\n\twhere ( 1 = 2 )\n'
+ self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)')
+ s = 'select -- foo\nfrom bar\n'
+ self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar')
+
+
+class TestFormatReindent(TestCaseBase):
+
+ def test_stmts(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select foo; select bar'
+ self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar')
+ s = 'select foo'
+ self.ndiffAssertEqual(f(s), 'select foo')
+ s = 'select foo; -- test\n select bar'
+ self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar')
+
+ def test_keywords(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select * from foo union select * from bar;'
+ self.ndiffAssertEqual(f(s), '\n'.join(['select *',
+ 'from foo',
+ 'union',
+ 'select *',
+ 'from bar;']))
+
+ def test_parenthesis(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select count(*) from (select * from foo);'
+ self.ndiffAssertEqual(f(s),
+ '\n'.join(['select count(*)',
+ 'from',
+ ' (select *',
+ ' from foo);',
+ ])
+ )
+
+ def test_where(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
+ self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
+ 'where bar = 1\n'
+ ' and baz = 2\n'
+ ' or bzz = 3;'))
+ s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
+ self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
+ 'where bar = 1\n'
+ ' and (baz = 2\n'
+ ' or bzz = 3);'))
+
+ def test_join(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select * from foo join bar on 1 = 2'
+ self.ndiffAssertEqual(f(s), '\n'.join(['select *',
+ 'from foo',
+ 'join bar on 1 = 2']))
+ s = 'select * from foo inner join bar on 1 = 2'
+ self.ndiffAssertEqual(f(s), '\n'.join(['select *',
+ 'from foo',
+ 'inner join bar on 1 = 2']))
+ s = 'select * from foo left outer join bar on 1 = 2'
+ self.ndiffAssertEqual(f(s), '\n'.join(['select *',
+ 'from foo',
+ 'left outer join bar on 1 = 2']
+ ))
+
+ def test_identifier_list(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'select foo, bar, baz from table1, table2 where 1 = 2'
+ self.ndiffAssertEqual(f(s), '\n'.join(['select foo,',
+ ' bar,',
+ ' baz',
+ 'from table1,',
+ ' table2',
+ 'where 1 = 2']))
+
+ def test_case(self):
+ f = lambda sql: sqlparse.format(sql, reindent=True)
+ s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
+ self.ndiffAssertEqual(f(s), '\n'.join(['case when foo = 1 then 2',
+ ' when foo = 3 then 4',
+ ' else 5',
+ 'end']))
+
+
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
new file mode 100644
index 0000000..fc3bea5
--- /dev/null
+++ b/tests/test_grouping.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+
+import sqlparse
+from sqlparse import tokens as T
+from sqlparse.engine.grouping import *
+
+from tests.utils import TestCaseBase
+
+
+class TestGrouping(TestCaseBase):
+
+ def test_parenthesis(self):
+ s ='x1 (x2 (x3) x2) foo (y2) bar'
+ parsed = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, str(parsed))
+ self.assertEqual(len(parsed.tokens), 9)
+ self.assert_(isinstance(parsed.tokens[2], Parenthesis))
+ self.assert_(isinstance(parsed.tokens[-3], Parenthesis))
+ self.assertEqual(len(parsed.tokens[2].tokens), 7)
+ self.assert_(isinstance(parsed.tokens[2].tokens[3], Parenthesis))
+ self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3)
+
+ def test_comments(self):
+ s = '/*\n * foo\n */ \n bar'
+ parsed = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, unicode(parsed))
+ self.assertEqual(len(parsed.tokens), 2)
+
+ def test_identifiers(self):
+ s = 'select foo.bar from "myscheme"."table" where fail. order'
+ parsed = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, parsed.to_unicode())
+ self.assert_(isinstance(parsed.tokens[2], Identifier))
+ self.assert_(isinstance(parsed.tokens[6], Identifier))
+ self.assert_(isinstance(parsed.tokens[8], Where))
+ s = 'select * from foo where foo.id = 1'
+ parsed = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, parsed.to_unicode())
+ self.assert_(isinstance(parsed.tokens[-1].tokens[2], Identifier))
+ s = 'select * from (select "foo"."id" from foo)'
+ parsed = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, parsed.to_unicode())
+ self.assert_(isinstance(parsed.tokens[-1].tokens[3], Identifier))
+
+ def test_where(self):
+ s = 'select * from foo where bar = 1 order by id desc'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertTrue(len(p.tokens), 16)
+ s = 'select x from (select y from foo where bar = 1) z'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertTrue(isinstance(p.tokens[-3].tokens[-1], Where))
+
+ def test_typecast(self):
+ s = 'select foo::integer from bar'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[2].get_typecast(), 'integer')
+ self.assertEqual(p.tokens[2].get_name(), 'foo')
+ s = 'select (current_database())::information_schema.sql_identifier'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[2].get_typecast(),
+ 'information_schema.sql_identifier')
+
+ def test_alias(self):
+ s = 'select foo as bar from mytable'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[2].get_real_name(), 'foo')
+ self.assertEqual(p.tokens[2].get_alias(), 'bar')
+ s = 'select foo from mytable t1'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[6].get_real_name(), 'mytable')
+ self.assertEqual(p.tokens[6].get_alias(), 't1')
+ s = 'select foo::integer as bar from mytable'
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[2].get_alias(), 'bar')
+ s = ('SELECT DISTINCT '
+ '(current_database())::information_schema.sql_identifier AS view')
+ p = sqlparse.parse(s)[0]
+ self.ndiffAssertEqual(s, p.to_unicode())
+ self.assertEqual(p.tokens[4].get_alias(), 'view')
diff --git a/tests/test_parse.py b/tests/test_parse.py
new file mode 100644
index 0000000..536b6f6
--- /dev/null
+++ b/tests/test_parse.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+
+"""Tests sqlparse function."""
+
+from utils import TestCaseBase
+
+import sqlparse
+
+
+class SQLParseTest(TestCaseBase):
+ """Tests sqlparse.parse()."""
+
+ def test_tokenize(self):
+ sql = 'select * from foo;'
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 1)
+ self.assertEqual(str(stmts[0]), sql)
+
+ def test_multistatement(self):
+ sql1 = 'select * from foo;'
+ sql2 = 'select * from bar;'
+ stmts = sqlparse.parse(sql1+sql2)
+ self.assertEqual(len(stmts), 2)
+ self.assertEqual(str(stmts[0]), sql1)
+ self.assertEqual(str(stmts[1]), sql2)
+
+ def test_newlines(self):
+ sql = u'select\n*from foo;'
+ p = sqlparse.parse(sql)[0]
+ self.assertEqual(unicode(p), sql)
+ sql = u'select\r\n*from foo'
+ p = sqlparse.parse(sql)[0]
+ self.assertEqual(unicode(p), sql)
+ sql = u'select\r*from foo'
+ p = sqlparse.parse(sql)[0]
+ self.assertEqual(unicode(p), sql)
+ sql = u'select\r\n*from foo\n'
+ p = sqlparse.parse(sql)[0]
+ self.assertEqual(unicode(p), sql)
diff --git a/tests/test_split.py b/tests/test_split.py
new file mode 100644
index 0000000..782b226
--- /dev/null
+++ b/tests/test_split.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+
+# Tests splitting functions.
+
+import unittest
+
+from utils import load_file, TestCaseBase
+
+import sqlparse
+
+
+class SQLSplitTest(TestCaseBase):
+ """Tests sqlparse.sqlsplit()."""
+
+ _sql1 = 'select * from foo;'
+ _sql2 = 'select * from bar;'
+
+ def test_split_semicolon(self):
+ sql2 = 'select * from foo where bar = \'foo;bar\';'
+ stmts = sqlparse.parse(''.join([self._sql1, sql2]))
+ self.assertEqual(len(stmts), 2)
+ self.ndiffAssertEqual(unicode(stmts[0]), self._sql1)
+ self.ndiffAssertEqual(unicode(stmts[1]), sql2)
+
+ def test_create_function(self):
+ sql = load_file('function.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 1)
+ self.ndiffAssertEqual(unicode(stmts[0]), sql)
+
+ def test_create_function_psql(self):
+ sql = load_file('function_psql.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 1)
+ self.ndiffAssertEqual(unicode(stmts[0]), sql)
+
+ def test_create_function_psql3(self):
+ sql = load_file('function_psql3.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 1)
+ self.ndiffAssertEqual(unicode(stmts[0]), sql)
+
+ def test_create_function_psql2(self):
+ sql = load_file('function_psql2.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 1)
+ self.ndiffAssertEqual(unicode(stmts[0]), sql)
+
+ def test_dashcomments(self):
+ sql = load_file('dashcomment.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 3)
+ self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+
+ def test_begintag(self):
+ sql = load_file('begintag.sql')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 3)
+ self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+
+ def test_dropif(self):
+ sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 2)
+ self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+
+ def test_comment_with_umlaut(self):
+ sql = (u'select * from foo;\n'
+ u'-- Testing an umlaut: ä\n'
+ u'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 2)
+ self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+
+ def test_comment_end_of_line(self):
+ sql = ('select * from foo; -- foo\n'
+ 'select * from bar;')
+ stmts = sqlparse.parse(sql)
+ self.assertEqual(len(stmts), 2)
+ self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql)
+ # make sure the comment belongs to first query
+ self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n')
+
+ def test_casewhen(self):
+ sql = ('SELECT case when val = 1 then 2 else null end as foo;\n'
+ 'comment on table actor is \'The actor table.\';')
+ stmts = sqlparse.split(sql)
+ self.assertEqual(len(stmts), 2)
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
new file mode 100644
index 0000000..7106b3c
--- /dev/null
+++ b/tests/test_tokenize.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+import types
+
+from sqlparse import lexer
+from sqlparse.tokens import *
+
+
+class TestTokenize(unittest.TestCase):
+
+ def test_simple(self):
+ sql = 'select * from foo;'
+ stream = lexer.tokenize(sql)
+ self.assert_(type(stream) is types.GeneratorType)
+ tokens = list(stream)
+ self.assertEqual(len(tokens), 8)
+ self.assertEqual(len(tokens[0]), 2)
+ self.assertEqual(tokens[0], (Keyword.DML, u'select'))
+ self.assertEqual(tokens[-1], (Punctuation, u';'))
+
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..a78b460
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+
+"""Helpers for testing."""
+
+import codecs
+import difflib
+import os
+import unittest
+from StringIO import StringIO
+
+NL = '\n'
+DIR_PATH = os.path.abspath(os.path.dirname(__file__))
+PARENT_DIR = os.path.dirname(DIR_PATH)
+FILES_DIR = os.path.join(DIR_PATH, 'files')
+
+
+def load_file(filename, encoding='utf-8'):
+ """Opens filename with encoding and return it's contents."""
+ f = codecs.open(os.path.join(FILES_DIR, filename), 'r', encoding)
+ data = f.read()
+ f.close()
+ return data
+
+
+class TestCaseBase(unittest.TestCase):
+ """Base class for test cases with some additional checks."""
+
+ # Adopted from Python's tests.
+ def ndiffAssertEqual(self, first, second):
+ """Like failUnlessEqual except use ndiff for readable output."""
+ if first <> second:
+ sfirst = unicode(first)
+ ssecond = unicode(second)
+ diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines())
+ fp = StringIO()
+ print >> fp, NL, NL.join(diff)
+ print fp.getvalue()
+ raise self.failureException, fp.getvalue()