summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2014-03-10 13:34:43 +0100
committerAndi Albrecht <albrecht.andi@gmail.com>2014-03-10 13:34:43 +0100
commitee9ff5c3eebc3c7579a9a37028fcc5388968a10e (patch)
treead9fe3c141c22113769a7f50b588f1d8c6156819
parentff7ba6404342898616be24115f7be4744520289d (diff)
parent480e52fddf28fad591f3214ee28c2d2af8842ce1 (diff)
downloadsqlparse-ee9ff5c3eebc3c7579a9a37028fcc5388968a10e.tar.gz
Merge pull request #132 from petronius/master
Fix SerializerUnicode to split unquoted newlines
-rw-r--r--sqlparse/filters.py7
-rw-r--r--sqlparse/utils.py43
-rw-r--r--tests/test_format.py17
-rw-r--r--tests/utils.py9
4 files changed, 71 insertions, 5 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 40caf51..5a613a0 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -11,6 +11,7 @@ from sqlparse.pipeline import Pipeline
from sqlparse.tokens import (Comment, Comparison, Keyword, Name, Punctuation,
String, Whitespace)
from sqlparse.utils import memoize_generator
+from sqlparse.utils import split_unquoted_newlines
# --------------------------
@@ -534,10 +535,8 @@ class SerializerUnicode:
def process(self, stack, stmt):
raw = unicode(stmt)
- add_nl = raw.endswith('\n')
- res = '\n'.join(line.rstrip() for line in raw.splitlines())
- if add_nl:
- res += '\n'
+ lines = split_unquoted_newlines(raw)
+ res = '\n'.join(line.rstrip() for line in lines)
return res
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index cdf27b1..2a7fb46 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -94,3 +94,46 @@ def memoize_generator(func):
yield item
return wrapped_func
+
+def split_unquoted_newlines(text):
+ """Split a string on all unquoted newlines
+
+ This is a fairly simplistic implementation of splitting a string on all
+ unescaped CR, LF, or CR+LF occurences. Only iterates the string once. Seemed
+ easier than a complex regular expression.
+ """
+ lines = ['']
+ quoted = None
+ escape_next = False
+ last_char = None
+ for c in text:
+ escaped = False
+ # If the previous character was an unescpaed '\', this character is
+ # escaped.
+ if escape_next:
+ escaped = True
+ escape_next = False
+ # If the current character is '\' and it is not escaped, the next
+ # character is escaped.
+ if c == '\\' and not escaped:
+ escape_next = True
+ # Start a quoted portion if a) we aren't in one already, and b) the
+ # quote isn't escaped.
+ if c in '"\'' and not escaped and not quoted:
+ quoted = c
+ # Escaped quotes (obvs) don't count as a closing match.
+ elif c == quoted and not escaped:
+ quoted = None
+
+ if not quoted and c in ['\r', '\n']:
+ if c == '\n' and last_char == '\r':
+ # It's a CR+LF, so don't append another line
+ pass
+ else:
+ lines.append('')
+ else:
+ lines[-1] += c
+
+ last_char = c
+
+ return lines \ No newline at end of file
diff --git a/tests/test_format.py b/tests/test_format.py
index 701540b..b77b7a1 100644
--- a/tests/test_format.py
+++ b/tests/test_format.py
@@ -77,6 +77,23 @@ class TestFormat(TestCaseBase):
s = 'select\n* /* foo */ from bar '
self.ndiffAssertEqual(f(s), 'select * /* foo */ from bar')
+ def test_notransform_of_quoted_crlf(self):
+ # Make sure that CR/CR+LF characters inside string literals don't get
+ # affected by the formatter.
+
+ s1 = "SELECT some_column LIKE 'value\r'"
+ s2 = "SELECT some_column LIKE 'value\r'\r\nWHERE id = 1\n"
+ s3 = "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\r"
+ s4 = "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\r\n"
+
+ f = lambda x: sqlparse.format(x)
+
+ # Because of the use of
+ self.ndiffAssertEqual(f(s1), "SELECT some_column LIKE 'value\r'")
+ self.ndiffAssertEqual(f(s2), "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
+ self.ndiffAssertEqual(f(s3), "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
+ self.ndiffAssertEqual(f(s4), "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")
+
def test_outputformat(self):
sql = 'select * from foo;'
self.assertRaises(SQLParseError, sqlparse.format, sql,
diff --git a/tests/utils.py b/tests/utils.py
index e2c01a3..9eb46bf 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,6 +8,8 @@ import os
import unittest
from StringIO import StringIO
+import sqlparse.utils
+
NL = '\n'
DIR_PATH = os.path.abspath(os.path.dirname(__file__))
PARENT_DIR = os.path.dirname(DIR_PATH)
@@ -31,7 +33,12 @@ class TestCaseBase(unittest.TestCase):
if first != second:
sfirst = unicode(first)
ssecond = unicode(second)
- diff = difflib.ndiff(sfirst.splitlines(), ssecond.splitlines())
+ # Using the built-in .splitlines() method here will cause incorrect
+ # results when splitting statements that have quoted CR/CR+LF
+ # characters.
+ sfirst = sqlparse.utils.split_unquoted_newlines(sfirst)
+ ssecond = sqlparse.utils.split_unquoted_newlines(ssecond)
+ diff = difflib.ndiff(sfirst, ssecond)
fp = StringIO()
fp.write(NL)
fp.write(NL.join(diff))