diff options
| author | Andi Albrecht <albrecht.andi@gmail.com> | 2009-04-03 21:26:42 +0200 |
|---|---|---|
| committer | Andi Albrecht <albrecht.andi@gmail.com> | 2009-04-03 21:26:42 +0200 |
| commit | 361122eb22d5681c58dac731009e4814b3dd5fa5 (patch) | |
| tree | b096496bc9c6b8febe092d0aefd56de1a4f8f4a0 /tests | |
| download | sqlparse-361122eb22d5681c58dac731009e4814b3dd5fa5.tar.gz | |
Initial import.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/__init__.py | 0 | ||||
| -rw-r--r-- | tests/files/begintag.sql | 4 | ||||
| -rw-r--r-- | tests/files/dashcomment.sql | 5 | ||||
| -rw-r--r-- | tests/files/function.sql | 13 | ||||
| -rw-r--r-- | tests/files/function_psql.sql | 72 | ||||
| -rw-r--r-- | tests/files/function_psql2.sql | 7 | ||||
| -rw-r--r-- | tests/files/function_psql3.sql | 8 | ||||
| -rwxr-xr-x | tests/run_tests.py | 31 | ||||
| -rw-r--r-- | tests/test_format.py | 146 | ||||
| -rw-r--r-- | tests/test_grouping.py | 86 | ||||
| -rw-r--r-- | tests/test_parse.py | 39 | ||||
| -rw-r--r-- | tests/test_split.py | 88 | ||||
| -rw-r--r-- | tests/test_tokenize.py | 21 | ||||
| -rw-r--r-- | tests/utils.py | 38 |
14 files changed, 558 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/files/begintag.sql b/tests/files/begintag.sql new file mode 100644 index 0000000..699b365 --- /dev/null +++ b/tests/files/begintag.sql @@ -0,0 +1,4 @@ +begin; +update foo + set bar = 1; +commit;
\ No newline at end of file diff --git a/tests/files/dashcomment.sql b/tests/files/dashcomment.sql new file mode 100644 index 0000000..0d5ac62 --- /dev/null +++ b/tests/files/dashcomment.sql @@ -0,0 +1,5 @@ +select * from user; +--select * from host; +select * from user; +select * -- foo; +from foo;
\ No newline at end of file diff --git a/tests/files/function.sql b/tests/files/function.sql new file mode 100644 index 0000000..d19227f --- /dev/null +++ b/tests/files/function.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE FUNCTION foo( + p_in1 VARCHAR + , p_in2 INTEGER +) RETURNS INTEGER AS + + DECLARE + v_foo INTEGER; + BEGIN + SELECT * + FROM foo + INTO v_foo; + RETURN v_foo.id; + END;
\ No newline at end of file diff --git a/tests/files/function_psql.sql b/tests/files/function_psql.sql new file mode 100644 index 0000000..e485f7a --- /dev/null +++ b/tests/files/function_psql.sql @@ -0,0 +1,72 @@ +CREATE OR REPLACE FUNCTION public.delete_data ( + p_tabelle VARCHAR + , p_key VARCHAR + , p_value INTEGER +) RETURNS INTEGER AS +$$ +DECLARE + p_retval INTEGER; + v_constraint RECORD; + v_count INTEGER; + v_data RECORD; + v_fieldname VARCHAR; + v_sql VARCHAR; + v_key VARCHAR; + v_value INTEGER; +BEGIN + v_sql := 'SELECT COUNT(*) FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value; + --RAISE NOTICE '%', v_sql; + EXECUTE v_sql INTO v_count; + IF v_count::integer != 0 THEN + SELECT att.attname + INTO v_key + FROM pg_attribute att + LEFT JOIN pg_constraint con ON con.conrelid = att.attrelid + AND con.conkey[1] = att.attnum + AND con.contype = 'p', pg_type typ, pg_class rel, pg_namespace ns + WHERE att.attrelid = rel.oid + AND att.attnum > 0 + AND typ.oid = att.atttypid + AND att.attisdropped = false + AND rel.relname = p_tabelle + AND con.conkey[1] = 1 + AND ns.oid = rel.relnamespace + AND ns.nspname = 'public' + ORDER BY att.attnum; + v_sql := 'SELECT ' || v_key || ' AS id FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value; + FOR v_data IN EXECUTE v_sql + LOOP + --RAISE NOTICE ' -> % %', p_tabelle, v_data.id; + FOR v_constraint IN SELECT t.constraint_name + , t.constraint_type + , t.table_name + , c.column_name + FROM public.v_table_constraints t + , public.v_constraint_columns c + WHERE t.constraint_name = c.constraint_name + AND t.constraint_type = 'FOREIGN KEY' + AND c.table_name = p_tabelle + AND t.table_schema = 'public' + AND c.table_schema = 'public' + LOOP + v_fieldname := substring(v_constraint.constraint_name from 1 for length(v_constraint.constraint_name) - length(v_constraint.column_name) - 1); + IF (v_constraint.table_name = p_tabelle) AND (p_value = v_data.id) THEN + --RAISE NOTICE 'Skip (Selbstverweis)'; + CONTINUE; + ELSE + PERFORM delete_data(v_constraint.table_name::varchar, v_fieldname::varchar, v_data.id::integer); + END IF; + END LOOP; + END LOOP; + v_sql := 'DELETE FROM ' || p_tabelle || ' WHERE ' || p_key || ' = ' || p_value; + --RAISE NOTICE '%', v_sql; + EXECUTE v_sql; + p_retval := 1; + ELSE + --RAISE NOTICE ' -> Keine Sätze gefunden'; + p_retval := 0; + END IF; + RETURN p_retval; +END; +$$ +LANGUAGE plpgsql;
\ No newline at end of file diff --git a/tests/files/function_psql2.sql b/tests/files/function_psql2.sql new file mode 100644 index 0000000..b5d494c --- /dev/null +++ b/tests/files/function_psql2.sql @@ -0,0 +1,7 @@ +CREATE OR REPLACE FUNCTION update_something() RETURNS void AS +$body$ +BEGIN + raise notice 'foo'; +END; +$body$ +LANGUAGE 'plpgsql' VOLATILE CALLED ON NULL INPUT SECURITY INVOKER;
\ No newline at end of file diff --git a/tests/files/function_psql3.sql b/tests/files/function_psql3.sql new file mode 100644 index 0000000..b25d818 --- /dev/null +++ b/tests/files/function_psql3.sql @@ -0,0 +1,8 @@ +CREATE OR REPLACE FUNCTION foo() RETURNS integer AS +$body$ +DECLARE +BEGIN + select * from foo; +END; +$body$ +LANGUAGE 'plpgsql' VOLATILE CALLED ON NULL INPUT SECURITY INVOKER;
\ No newline at end of file diff --git a/tests/run_tests.py b/tests/run_tests.py new file mode 100755 index 0000000..1c7960e --- /dev/null +++ b/tests/run_tests.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Test runner for sqlparse.""" + +import os +import sys +import unittest + +sys.path.insert(1, os.path.join(os.path.dirname(__file__), '../')) + + +def main(): + """Create a TestSuite and run it.""" + loader = unittest.TestLoader() + suite = unittest.TestSuite() + fnames = [os.path.split(f)[-1] for f in sys.argv[1:]] + for fname in os.listdir(os.path.dirname(__file__)): + if (not fname.startswith('test_') or not fname.endswith('.py') + or (fnames and fname not in fnames)): + continue + modname = os.path.splitext(fname)[0] + mod = __import__(os.path.splitext(fname)[0]) + suite.addTests(loader.loadTestsFromModule(mod)) + unittest.TextTestRunner(verbosity=2).run(suite) + + + + +if __name__ == '__main__': + main() diff --git a/tests/test_format.py b/tests/test_format.py new file mode 100644 index 0000000..5748704 --- /dev/null +++ b/tests/test_format.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- + +from tests.utils import TestCaseBase + +import sqlparse + + +class TestFormat(TestCaseBase): + + def test_keywordcase(self): + sql = 'select * from bar; -- select foo\n' + res = sqlparse.format(sql, keyword_case='upper') + self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n') + res = sqlparse.format(sql, keyword_case='capitalize') + self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n') + res = sqlparse.format(sql.upper(), keyword_case='lower') + self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n') + self.assertRaises(sqlparse.SQLParseError, sqlparse.format, sql, + keyword_case='foo') + + def test_identifiercase(self): + sql = 'select * from bar; -- select foo\n' + res = sqlparse.format(sql, identifier_case='upper') + self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n') + res = sqlparse.format(sql, identifier_case='capitalize') + self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n') + res = sqlparse.format(sql.upper(), identifier_case='lower') + self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n') + self.assertRaises(sqlparse.SQLParseError, sqlparse.format, sql, + identifier_case='foo') + sql = 'select * from "foo"."bar"' + res = sqlparse.format(sql, identifier_case="upper") + self.ndiffAssertEqual(res, 'select * from "FOO"."BAR"') + + def test_strip_comments_single(self): + sql = 'select *-- statement starts here\nfrom foo' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select * from foo') + sql = 'select * -- statement starts here\nfrom foo' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select * from foo') + sql = 'select-- foo\nfrom -- bar\nwhere' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select from where') + + def test_strip_comments_multi(self): + sql = '/* sql starts here */\nselect' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select') + sql = '/* sql starts here */ select' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select') + sql = '/*\n * sql starts here\n */\nselect' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select') + sql = 'select (/* sql starts here */ select 2)' + res = sqlparse.format(sql, strip_comments=True) + self.ndiffAssertEqual(res, 'select (select 2)') + + def test_strip_ws(self): + f = lambda sql: sqlparse.format(sql, strip_whitespace=True) + s = 'select\n* from foo\n\twhere ( 1 = 2 )\n' + self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)') + s = 'select -- foo\nfrom bar\n' + self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar') + + +class TestFormatReindent(TestCaseBase): + + def test_stmts(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select foo; select bar' + self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar') + s = 'select foo' + self.ndiffAssertEqual(f(s), 'select foo') + s = 'select foo; -- test\n select bar' + self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar') + + def test_keywords(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select * from foo union select * from bar;' + self.ndiffAssertEqual(f(s), '\n'.join(['select *', + 'from foo', + 'union', + 'select *', + 'from bar;'])) + + def test_parenthesis(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select count(*) from (select * from foo);' + self.ndiffAssertEqual(f(s), + '\n'.join(['select count(*)', + 'from', + ' (select *', + ' from foo);', + ]) + ) + + def test_where(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;' + self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and baz = 2\n' + ' or bzz = 3;')) + s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);' + self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n' + 'where bar = 1\n' + ' and (baz = 2\n' + ' or bzz = 3);')) + + def test_join(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select * from foo join bar on 1 = 2' + self.ndiffAssertEqual(f(s), '\n'.join(['select *', + 'from foo', + 'join bar on 1 = 2'])) + s = 'select * from foo inner join bar on 1 = 2' + self.ndiffAssertEqual(f(s), '\n'.join(['select *', + 'from foo', + 'inner join bar on 1 = 2'])) + s = 'select * from foo left outer join bar on 1 = 2' + self.ndiffAssertEqual(f(s), '\n'.join(['select *', + 'from foo', + 'left outer join bar on 1 = 2'] + )) + + def test_identifier_list(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'select foo, bar, baz from table1, table2 where 1 = 2' + self.ndiffAssertEqual(f(s), '\n'.join(['select foo,', + ' bar,', + ' baz', + 'from table1,', + ' table2', + 'where 1 = 2'])) + + def test_case(self): + f = lambda sql: sqlparse.format(sql, reindent=True) + s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end' + self.ndiffAssertEqual(f(s), '\n'.join(['case when foo = 1 then 2', + ' when foo = 3 then 4', + ' else 5', + 'end'])) + + diff --git a/tests/test_grouping.py b/tests/test_grouping.py new file mode 100644 index 0000000..fc3bea5 --- /dev/null +++ b/tests/test_grouping.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +import sqlparse +from sqlparse import tokens as T +from sqlparse.engine.grouping import * + +from tests.utils import TestCaseBase + + +class TestGrouping(TestCaseBase): + + def test_parenthesis(self): + s ='x1 (x2 (x3) x2) foo (y2) bar' + parsed = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, str(parsed)) + self.assertEqual(len(parsed.tokens), 9) + self.assert_(isinstance(parsed.tokens[2], Parenthesis)) + self.assert_(isinstance(parsed.tokens[-3], Parenthesis)) + self.assertEqual(len(parsed.tokens[2].tokens), 7) + self.assert_(isinstance(parsed.tokens[2].tokens[3], Parenthesis)) + self.assertEqual(len(parsed.tokens[2].tokens[3].tokens), 3) + + def test_comments(self): + s = '/*\n * foo\n */ \n bar' + parsed = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, unicode(parsed)) + self.assertEqual(len(parsed.tokens), 2) + + def test_identifiers(self): + s = 'select foo.bar from "myscheme"."table" where fail. order' + parsed = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, parsed.to_unicode()) + self.assert_(isinstance(parsed.tokens[2], Identifier)) + self.assert_(isinstance(parsed.tokens[6], Identifier)) + self.assert_(isinstance(parsed.tokens[8], Where)) + s = 'select * from foo where foo.id = 1' + parsed = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, parsed.to_unicode()) + self.assert_(isinstance(parsed.tokens[-1].tokens[2], Identifier)) + s = 'select * from (select "foo"."id" from foo)' + parsed = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, parsed.to_unicode()) + self.assert_(isinstance(parsed.tokens[-1].tokens[3], Identifier)) + + def test_where(self): + s = 'select * from foo where bar = 1 order by id desc' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertTrue(len(p.tokens), 16) + s = 'select x from (select y from foo where bar = 1) z' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertTrue(isinstance(p.tokens[-3].tokens[-1], Where)) + + def test_typecast(self): + s = 'select foo::integer from bar' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[2].get_typecast(), 'integer') + self.assertEqual(p.tokens[2].get_name(), 'foo') + s = 'select (current_database())::information_schema.sql_identifier' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[2].get_typecast(), + 'information_schema.sql_identifier') + + def test_alias(self): + s = 'select foo as bar from mytable' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[2].get_real_name(), 'foo') + self.assertEqual(p.tokens[2].get_alias(), 'bar') + s = 'select foo from mytable t1' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[6].get_real_name(), 'mytable') + self.assertEqual(p.tokens[6].get_alias(), 't1') + s = 'select foo::integer as bar from mytable' + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[2].get_alias(), 'bar') + s = ('SELECT DISTINCT ' + '(current_database())::information_schema.sql_identifier AS view') + p = sqlparse.parse(s)[0] + self.ndiffAssertEqual(s, p.to_unicode()) + self.assertEqual(p.tokens[4].get_alias(), 'view') diff --git a/tests/test_parse.py b/tests/test_parse.py new file mode 100644 index 0000000..536b6f6 --- /dev/null +++ b/tests/test_parse.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +"""Tests sqlparse function.""" + +from utils import TestCaseBase + +import sqlparse + + +class SQLParseTest(TestCaseBase): + """Tests sqlparse.parse().""" + + def test_tokenize(self): + sql = 'select * from foo;' + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 1) + self.assertEqual(str(stmts[0]), sql) + + def test_multistatement(self): + sql1 = 'select * from foo;' + sql2 = 'select * from bar;' + stmts = sqlparse.parse(sql1+sql2) + self.assertEqual(len(stmts), 2) + self.assertEqual(str(stmts[0]), sql1) + self.assertEqual(str(stmts[1]), sql2) + + def test_newlines(self): + sql = u'select\n*from foo;' + p = sqlparse.parse(sql)[0] + self.assertEqual(unicode(p), sql) + sql = u'select\r\n*from foo' + p = sqlparse.parse(sql)[0] + self.assertEqual(unicode(p), sql) + sql = u'select\r*from foo' + p = sqlparse.parse(sql)[0] + self.assertEqual(unicode(p), sql) + sql = u'select\r\n*from foo\n' + p = sqlparse.parse(sql)[0] + self.assertEqual(unicode(p), sql) diff --git a/tests/test_split.py b/tests/test_split.py new file mode 100644 index 0000000..782b226 --- /dev/null +++ b/tests/test_split.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- + +# Tests splitting functions. + +import unittest + +from utils import load_file, TestCaseBase + +import sqlparse + + +class SQLSplitTest(TestCaseBase): + """Tests sqlparse.sqlsplit().""" + + _sql1 = 'select * from foo;' + _sql2 = 'select * from bar;' + + def test_split_semicolon(self): + sql2 = 'select * from foo where bar = \'foo;bar\';' + stmts = sqlparse.parse(''.join([self._sql1, sql2])) + self.assertEqual(len(stmts), 2) + self.ndiffAssertEqual(unicode(stmts[0]), self._sql1) + self.ndiffAssertEqual(unicode(stmts[1]), sql2) + + def test_create_function(self): + sql = load_file('function.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 1) + self.ndiffAssertEqual(unicode(stmts[0]), sql) + + def test_create_function_psql(self): + sql = load_file('function_psql.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 1) + self.ndiffAssertEqual(unicode(stmts[0]), sql) + + def test_create_function_psql3(self): + sql = load_file('function_psql3.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 1) + self.ndiffAssertEqual(unicode(stmts[0]), sql) + + def test_create_function_psql2(self): + sql = load_file('function_psql2.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 1) + self.ndiffAssertEqual(unicode(stmts[0]), sql) + + def test_dashcomments(self): + sql = load_file('dashcomment.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 3) + self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + + def test_begintag(self): + sql = load_file('begintag.sql') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 3) + self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + + def test_dropif(self): + sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;' + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 2) + self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + + def test_comment_with_umlaut(self): + sql = (u'select * from foo;\n' + u'-- Testing an umlaut: ä\n' + u'select * from bar;') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 2) + self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + + def test_comment_end_of_line(self): + sql = ('select * from foo; -- foo\n' + 'select * from bar;') + stmts = sqlparse.parse(sql) + self.assertEqual(len(stmts), 2) + self.ndiffAssertEqual(''.join(unicode(q) for q in stmts), sql) + # make sure the comment belongs to first query + self.ndiffAssertEqual(unicode(stmts[0]), 'select * from foo; -- foo\n') + + def test_casewhen(self): + sql = ('SELECT case when val = 1 then 2 else null end as foo;\n' + 'comment on table actor is \'The actor table.\';') + stmts = sqlparse.split(sql) + self.assertEqual(len(stmts), 2) diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py new file mode 100644 index 0000000..7106b3c --- /dev/null +++ b/tests/test_tokenize.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +import unittest +import types + +from sqlparse import lexer +from sqlparse.tokens import * + + +class TestTokenize(unittest.TestCase): + + def test_simple(self): + sql = 'select * from foo;' + stream = lexer.tokenize(sql) + self.assert_(type(stream) is types.GeneratorType) + tokens = list(stream) + self.assertEqual(len(tokens), 8) + self.assertEqual(len(tokens[0]), 2) + self.assertEqual(tokens[0], (Keyword.DML, u'select')) + self.assertEqual(tokens[-1], (Punctuation, u';')) + diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..a78b460 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +"""Helpers for testing.""" + +import codecs +import difflib +import os +import unittest +from StringIO import StringIO + +NL = '\n' +DIR_PATH = os.path.abspath(os.path.dirname(__file__)) +PARENT_DIR = os.path.dirname(DIR_PATH) +FILES_DIR = os.path.join(DIR_PATH, 'files') + + +def load_file(filename, encoding='utf-8'): + """Opens filename with encoding and return it's contents.""" + f = codecs.open(os.path.join(FILES_DIR, filename), 'r', encoding) + data = f.read() + f.close() + return data + + +class TestCaseBase(unittest.TestCase): + """Base class for test cases with some additional checks.""" + + # Adopted from Python's tests. + def ndiffAssertEqual(self, first, second): + """Like failUnlessEqual except use ndiff for readable output.""" + if first <> second: + sfirst = unicode(first) + ssecond = unicode(second) + diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines()) + fp = StringIO() + print >> fp, NL, NL.join(diff) + print fp.getvalue() + raise self.failureException, fp.getvalue() |
