diff options
author | Andi Albrecht <albrecht.andi@gmail.com> | 2014-03-10 13:34:43 +0100 |
---|---|---|
committer | Andi Albrecht <albrecht.andi@gmail.com> | 2014-03-10 13:34:43 +0100 |
commit | ee9ff5c3eebc3c7579a9a37028fcc5388968a10e (patch) | |
tree | ad9fe3c141c22113769a7f50b588f1d8c6156819 | |
parent | ff7ba6404342898616be24115f7be4744520289d (diff) | |
parent | 480e52fddf28fad591f3214ee28c2d2af8842ce1 (diff) | |
download | sqlparse-ee9ff5c3eebc3c7579a9a37028fcc5388968a10e.tar.gz |
Merge pull request #132 from petronius/master
Fix SerializerUnicode to split unquoted newlines
-rw-r--r-- | sqlparse/filters.py | 7 | ||||
-rw-r--r-- | sqlparse/utils.py | 43 | ||||
-rw-r--r-- | tests/test_format.py | 17 | ||||
-rw-r--r-- | tests/utils.py | 9 |
4 files changed, 71 insertions, 5 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py index 40caf51..5a613a0 100644 --- a/sqlparse/filters.py +++ b/sqlparse/filters.py @@ -11,6 +11,7 @@ from sqlparse.pipeline import Pipeline from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation, String, Whitespace) from sqlparse.utils import memoize_generator +from sqlparse.utils import split_unquoted_newlines # -------------------------- @@ -534,10 +535,8 @@ class SerializerUnicode: def process(self, stack, stmt): raw = unicode(stmt) - add_nl = raw.endswith('\n') - res = '\n'.join(line.rstrip() for line in raw.splitlines()) - if add_nl: - res += '\n' + lines = split_unquoted_newlines(raw) + res = '\n'.join(line.rstrip() for line in lines) return res diff --git a/sqlparse/utils.py b/sqlparse/utils.py index cdf27b1..2a7fb46 100644 --- a/sqlparse/utils.py +++ b/sqlparse/utils.py @@ -94,3 +94,46 @@ def memoize_generator(func): yield item return wrapped_func + +def split_unquoted_newlines(text): + """Split a string on all unquoted newlines + + This is a fairly simplistic implementation of splitting a string on all + unescaped CR, LF, or CR+LF occurences. Only iterates the string once. Seemed + easier than a complex regular expression. + """ + lines = [''] + quoted = None + escape_next = False + last_char = None + for c in text: + escaped = False + # If the previous character was an unescpaed '\', this character is + # escaped. + if escape_next: + escaped = True + escape_next = False + # If the current character is '\' and it is not escaped, the next + # character is escaped. + if c == '\\' and not escaped: + escape_next = True + # Start a quoted portion if a) we aren't in one already, and b) the + # quote isn't escaped. + if c in '"\'' and not escaped and not quoted: + quoted = c + # Escaped quotes (obvs) don't count as a closing match. + elif c == quoted and not escaped: + quoted = None + + if not quoted and c in ['\r', '\n']: + if c == '\n' and last_char == '\r': + # It's a CR+LF, so don't append another line + pass + else: + lines.append('') + else: + lines[-1] += c + + last_char = c + + return lines
\ No newline at end of file diff --git a/tests/test_format.py b/tests/test_format.py index 701540b..b77b7a1 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -77,6 +77,23 @@ class TestFormat(TestCaseBase): s = 'select\n* /* foo */ from bar ' self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar') + def test_notransform_of_quoted_crlf(self): + # Make sure that CR/CR+LF characters inside string literals don't get + # affected by the formatter. + + s1 = "SELECT some_column LIKE 'value\r'" + s2 = "SELECT some_column LIKE 'value\r'\r\nWHERE id = 1\n" + s3 = "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\r" + s4 = "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\r\n" + + f = lambda x: sqlparse.format(x) + + # Because of the use of + self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'") + self.ndiffAssertEqual(f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n") + self.ndiffAssertEqual(f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n") + self.ndiffAssertEqual(f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n") + def test_outputformat(self): sql = 'select * from foo;' self.assertRaises(SQLParseError, sqlparse.format, sql, diff --git a/tests/utils.py b/tests/utils.py index e2c01a3..9eb46bf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,6 +8,8 @@ import os import unittest from StringIO import StringIO +import sqlparse.utils + NL = '\n' DIR_PATH = os.path.abspath(os.path.dirname(__file__)) PARENT_DIR = os.path.dirname(DIR_PATH) @@ -31,7 +33,12 @@ class TestCaseBase(unittest.TestCase): if first != second: sfirst = unicode(first) ssecond = unicode(second) - diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines()) + # Using the built-in .splitlines() method here will cause incorrect + # results when splitting statements that have quoted CR/CR+LF + # characters. + sfirst = sqlparse.utils.split_unquoted_newlines(sfirst) + ssecond = sqlparse.utils.split_unquoted_newlines(ssecond) + diff = difflib.ndiff(sfirst, ssecond) fp = StringIO() fp.write(NL) fp.write(NL.join(diff)) |