# -*- coding: utf-8 -*-
"""Tests for the options accepted by ``sqlparse.format``.

Covers keyword/identifier case conversion, comment and whitespace
stripping, reindentation, alternative output formats and string-literal
truncation, including rejection of invalid option values.
"""

import functools

import pytest

from tests.utils import TestCaseBase

import sqlparse
from sqlparse.exceptions import SQLParseError


class TestFormat(TestCaseBase):
    """Case conversion, comment stripping and whitespace options."""

    def test_keywordcase(self):
        """keyword_case changes keywords only; bad modes are rejected."""
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, keyword_case='upper')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
        res = sqlparse.format(sql, keyword_case='capitalize')
        self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), keyword_case='lower')
        self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          keyword_case='foo')

    def test_identifiercase(self):
        """identifier_case changes unquoted identifiers only.

        Keywords, comment text and quoted identifiers keep their case,
        and an unknown mode raises SQLParseError.
        """
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, identifier_case='upper')
        self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
        res = sqlparse.format(sql, identifier_case='capitalize')
        self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), identifier_case='lower')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          identifier_case='foo')
        # Quoted identifiers must be left untouched.
        sql = 'select * from "foo"."bar"'
        res = sqlparse.format(sql, identifier_case="upper")
        self.ndiffAssertEqual(res, 'select * from "foo"."bar"')

    def test_strip_comments_single(self):
        """strip_comments removes ``--`` comments; None is rejected."""
        sql = 'select *-- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select * -- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select-- foo\nfrom -- bar\nwhere'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select from where')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          strip_comments=None)

    def test_strip_comments_multi(self):
        """strip_comments removes ``/* */`` comments, even inside parens."""
        sql = '/* sql starts here */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/* sql starts here */ select'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/*\n * sql starts here\n */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = 'select (/* sql starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')

    def test_strip_ws(self):
        """strip_whitespace collapses runs but keeps a comment's newline."""
        f = functools.partial(sqlparse.format, strip_whitespace=True)
        s = 'select\n* from foo\n\twhere ( 1 = 2 )\n'
        self.ndiffAssertEqual(f(s), 'select * from foo where (1 = 2)')
        # The newline ending a "--" comment is significant and must stay.
        s = 'select -- foo\nfrom bar\n'
        self.ndiffAssertEqual(f(s), 'select -- foo\nfrom bar')
        self.assertRaises(SQLParseError, sqlparse.format, s,
                          strip_whitespace=None)

    def test_preserve_ws(self):
        # preserve at least one whitespace after subgroups
        f = functools.partial(sqlparse.format, strip_whitespace=True)
        s = 'select\n* /* foo */ from bar '
        self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar')

    def test_outputformat(self):
        """An unknown output_format raises SQLParseError."""
        sql = 'select * from foo;'
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          output_format='foo')


# NOTE(review): several expected strings in the reindent tests below have
# continuation lines padded by a single space (e.g. ' bar,', ' and baz = 2').
# This file appears to have been whitespace-collapsed at some point, so the
# original column-alignment padding may have been lost -- confirm these
# literals against the real output of sqlparse.format(..., reindent=True).
class TestFormatReindent(TestCaseBase):
    """Tests for ``reindent=True`` and its related indentation options."""

    def test_option(self):
        """Invalid reindent/indent_tabs/indent_width values are rejected."""
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          indent_tabs=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width='foo')
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width=-12)

    def test_stmts(self):
        """Statements are separated by a blank line."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select foo; select bar'
        self.ndiffAssertEqual(f(s), 'select foo;\n\nselect bar')
        s = 'select foo'
        self.ndiffAssertEqual(f(s), 'select foo')
        # A trailing "--" comment stays on its statement's line.
        s = 'select foo; -- test\n select bar'
        self.ndiffAssertEqual(f(s), 'select foo; -- test\n\nselect bar')

    def test_keywords(self):
        """FROM and UNION each start a new line."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select * from foo union select * from bar;'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'union',
                                               'select *',
                                               'from bar;']))

    def test_keywords_between(self):  # issue 14
        # don't break AND after BETWEEN
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'and foo between 1 and 2 and bar = 3'
        self.ndiffAssertEqual(f(s), '\n'.join(['',
                                               'and foo between 1 and 2',
                                               'and bar = 3']))

    def test_parenthesis(self):
        """A parenthesized subquery after FROM goes on its own lines."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select count(*) from (select * from foo);'
        self.ndiffAssertEqual(f(s),
                              '\n'.join(['select count(*)',
                                         'from',
                                         ' (select *',
                                         ' from foo);',
                                         ])
                              )

    def test_where(self):
        """AND/OR in WHERE go on continuation lines, also inside parens."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
        self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
                                     'where bar = 1\n'
                                     ' and baz = 2\n'
                                     ' or bzz = 3;'))
        s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
        self.ndiffAssertEqual(f(s), ('select *\nfrom foo\n'
                                     'where bar = 1\n'
                                     ' and (baz = 2\n'
                                     ' or bzz = 3);'))

    def test_join(self):
        """Each JOIN variant starts a new line together with its ON."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select * from foo join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'join bar on 1 = 2']))
        s = 'select * from foo inner join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'inner join bar on 1 = 2']))
        s = 'select * from foo left outer join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'left outer join bar on 1 = 2']
                                              ))
        s = 'select * from foo straight_join bar on 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select *',
                                               'from foo',
                                               'straight_join bar on 1 = 2']
                                              ))

    def test_identifier_list(self):
        """SELECT and FROM lists are split one item per line."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select foo, bar, baz from table1, table2 where 1 = 2'
        self.ndiffAssertEqual(f(s), '\n'.join(['select foo,',
                                               ' bar,',
                                               ' baz',
                                               'from table1,',
                                               ' table2',
                                               'where 1 = 2']))
        s = 'select a.*, b.id from a, b'
        self.ndiffAssertEqual(f(s), '\n'.join(['select a.*,',
                                               ' b.id',
                                               'from a,',
                                               ' b']))

    def test_identifier_list_with_functions(self):
        """Function calls and ``||`` inside a SELECT list stay intact."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
             "col3 from my_table")
        self.ndiffAssertEqual(f(s), '\n'.join(
            ["select 'abc' as foo,",
             " coalesce(col1, col2)||col3 as bar,",
             " col3",
             "from my_table"]))

    def test_case(self):
        """CASE branches go on their own lines; END closes the block."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
        self.ndiffAssertEqual(f(s), '\n'.join(['case',
                                               ' when foo = 1 then 2',
                                               ' when foo = 3 then 4',
                                               ' else 5',
                                               'end']))

    def test_case2(self):
        """Same as test_case, with a parenthesized CASE operand."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'case(foo) when bar = 1 then 2 else 3 end'
        self.ndiffAssertEqual(f(s), '\n'.join(['case(foo)',
                                               ' when bar = 1 then 2',
                                               ' else 3',
                                               'end']))

    def test_nested_identifier_list(self):  # issue4
        f = functools.partial(sqlparse.format, reindent=True)
        s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
        self.ndiffAssertEqual(f(s), '\n'.join(['(foo as bar,',
                                               ' bar1,',
                                               ' bar2 as bar3,',
                                               ' b4 as b5)']))

    def test_duplicate_linebreaks(self):  # issue3
        """Newlines already in the input do not produce blank lines."""
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select c1 -- column1\nfrom foo'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1 -- column1',
                                               'from foo']))
        s = 'select c1 -- column1\nfrom foo'
        r = sqlparse.format(s, reindent=True, strip_comments=True)
        self.ndiffAssertEqual(r, '\n'.join(['select c1',
                                            'from foo']))
        s = 'select c1\nfrom foo\norder by c1'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
                                               'from foo',
                                               'order by c1']))
        s = 'select c1 from t1 where (c1 = 1) order by c1'
        self.ndiffAssertEqual(f(s), '\n'.join(['select c1',
                                               'from t1',
                                               'where (c1 = 1)',
                                               'order by c1']))

    def test_keywordfunctions(self):  # issue36
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select max(a) b, foo, bar'
        self.ndiffAssertEqual(f(s), '\n'.join(['select max(a) b,',
                                               ' foo,',
                                               ' bar']))

    def test_identifier_and_functions(self):  # issue45
        f = functools.partial(sqlparse.format, reindent=True)
        s = 'select foo.bar, nvl(1) from dual'
        self.ndiffAssertEqual(f(s), '\n'.join(['select foo.bar,',
                                               ' nvl(1)',
                                               'from dual']))


class TestOutputFormat(TestCaseBase):
    """Tests for the python / php / sql values of ``output_format``."""

    def test_python(self):
        """Python output wraps the statement in an ``sql = ...`` literal."""
        sql = 'select * from foo;'
        f = functools.partial(sqlparse.format, output_format='python')
        self.ndiffAssertEqual(f(sql), "sql = 'select * from foo;'")
        f = functools.partial(sqlparse.format, output_format='python',
                              reindent=True)
        # NOTE(review): the single-space padding of the continuation line
        # may be collapsed alignment whitespace -- confirm against output.
        self.ndiffAssertEqual(f(sql), ("sql = ('select * '\n"
                                       " 'from foo;')"))

    def test_php(self):
        """PHP output assigns to ``$sql`` and appends lines with ``.=``."""
        sql = 'select * from foo;'
        f = functools.partial(sqlparse.format, output_format='php')
        self.ndiffAssertEqual(f(sql), '$sql = "select * from foo;";')
        f = functools.partial(sqlparse.format, output_format='php',
                              reindent=True)
        self.ndiffAssertEqual(f(sql), ('$sql = "select * ";\n'
                                       '$sql .= "from foo;";'))

    def test_sql(self):
        # "sql" is an allowed option but has no effect
        sql = 'select * from foo;'
        f = functools.partial(sqlparse.format, output_format='sql')
        self.ndiffAssertEqual(f(sql), 'select * from foo;')


def test_format_column_ordering():  # issue89
    """ORDER BY columns are split per line, keeping DESC with its column."""
    sql = 'select * from foo order by c1 desc, c2, c3;'
    formatted = sqlparse.format(sql, reindent=True)
    # NOTE(review): the single-space padding on ' c2,' / ' c3;' may be
    # collapsed alignment whitespace -- confirm against the real output.
    expected = '\n'.join(['select *', 'from foo',
                          'order by c1 desc,',
                          ' c2,',
                          ' c3;'])
    assert formatted == expected


def test_truncate_strings():
    """Long string literals are cut and suffixed with truncate_char.

    The suffix defaults to '[...]' when truncate_char is not given.
    """
    sql = 'update foo set value = \'' + 'x' * 1000 + '\';'
    formatted = sqlparse.format(sql, truncate_strings=10)
    assert formatted == 'update foo set value = \'xxxxxxxxxx[...]\';'
    formatted = sqlparse.format(sql, truncate_strings=3, truncate_char='YYY')
    assert formatted == 'update foo set value = \'xxxYYY\';'


def test_truncate_strings_invalid_option():
    """Non-integer and non-positive truncate_strings values are rejected."""
    for bad_value in ('bar', -1, 0):
        with pytest.raises(SQLParseError):
            sqlparse.format('foo', truncate_strings=bad_value)


@pytest.mark.parametrize('sql', ['select verrrylongcolumn from foo',
                                 'select "verrrylongcolumn" from "foo"'])
def test_truncate_strings_doesnt_truncate_identifiers(sql):
    """Identifiers, quoted or not, are never truncated."""
    formatted = sqlparse.format(sql, truncate_strings=2)
    assert formatted == sql