-rw-r--r--  sqlparse/__init__.py    8
-rw-r--r--  sqlparse/lexer.py      57
-rw-r--r--  tests/test_split.py    15
-rw-r--r--  tests/test_tokenize.py 39
4 files changed, 101 insertions(+), 18 deletions(-)
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index f924c04..bacaa78 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -54,6 +54,14 @@ def split(sql):
     stack.split_statements = True
     return [unicode(stmt) for stmt in stack.run(sql)]
 
 
+def splitstream(sql):
+    """Split *sql* into single statements.
+
+    Returns a generator of statements.
+    """
+    stack = engine.FilterStack()
+    stack.split_statements = True
+    return stack.run(sql)
 from sqlparse.engine.filter import StatementFilter
 def split2(stream):
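Note: unlike split(), the new splitstream() hands back the FilterStack generator directly instead of materializing a list, so statements are produced lazily and the input may be a file-like object (see the tests added below). A minimal usage sketch based on those tests:

    import sqlparse
    from cStringIO import StringIO

    # statements are yielded one at a time while the stream is consumed
    for stmt in sqlparse.splitstream(StringIO("SELECT 1; SELECT 2;")):
        print stmt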
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index 321669d..67dbc29 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -159,6 +159,7 @@ class Lexer(object):
     stripnl = False
     tabsize = 0
     flags = re.IGNORECASE
+    bufsize = 4096
 
     tokens = {
         'root': [
@@ -214,6 +215,21 @@ class Lexer(object):
             filter_ = filter_(**options)
         self.filters.append(filter_)
 
+    def _decode(self, text):
+        if self.encoding == 'guess':
+            try:
+                text = text.decode('utf-8')
+                if text.startswith(u'\ufeff'):
+                    text = text[len(u'\ufeff'):]
+            except UnicodeDecodeError:
+                text = text.decode('latin1')
+        else:
+            text = text.decode(self.encoding)
+
+        if self.tabsize > 0:
+            text = text.expandtabs(self.tabsize)
+        return text
+
     def get_tokens(self, text, unfiltered=False):
         """
         Return an iterable of (tokentype, value) pairs generated from
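Note: _decode() factors the former inline decoding out of get_tokens() so the streaming code further down can reuse it per buffer. With encoding set to 'guess' it tries UTF-8 first, strips a leading BOM, and falls back to latin-1; otherwise it decodes with the configured encoding, then expands tabs when tabsize is positive. A standalone sketch of the guess path (illustrative helper, not part of this commit):

    def guess_decode(data):
        # try UTF-8 and drop a BOM, falling back to latin-1,
        # mirroring Lexer._decode when encoding == 'guess'
        try:
            text = data.decode('utf-8')
            if text.startswith(u'\ufeff'):
                text = text[len(u'\ufeff'):]
        except UnicodeDecodeError:
            text = data.decode('latin1')
        return text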
@@ -223,24 +239,14 @@ class Lexer(object):
         Also preprocess the text, i.e. expand tabs and strip it if
         wanted and applies registered filters.
         """
-        if not isinstance(text, unicode):
-            if self.encoding == 'guess':
-                try:
-                    text = text.decode('utf-8')
-                    if text.startswith(u'\ufeff'):
-                        text = text[len(u'\ufeff'):]
-                except UnicodeDecodeError:
-                    text = text.decode('latin1')
-            else:
-                text = text.decode(self.encoding)
-        if self.stripall:
-            text = text.strip()
-        elif self.stripnl:
-            text = text.strip('\n')
-        if self.tabsize > 0:
-            text = text.expandtabs(self.tabsize)
-#        if not text.endswith('\n'):
-#            text += '\n'
+        if isinstance(text, str):
+            text = self._decode(text)
+
+        if isinstance(text, basestring):
+            if self.stripall:
+                text = text.strip()
+            elif self.stripnl:
+                text = text.strip('\n')
 
         def streamer():
             for i, t, v in self.get_tokens_unprocessed(text):
@@ -261,10 +267,19 @@ class Lexer(object):
         statestack = list(stack)
         statetokens = tokendefs[statestack[-1]]
         known_names = {}
+
+        hasmore = False
+        if hasattr(text, 'read'):
+            o, text = text, self._decode(text.read(self.bufsize))
+            hasmore = len(text) == self.bufsize
+
         while 1:
             for rexmatch, action, new_state in statetokens:
                 m = rexmatch(text, pos)
                 if m:
+                    if hasmore and m.end() == len(text):
+                        continue
+
                     # print rex.pattern
                     value = m.group()
                     if value in known_names:
@@ -307,6 +322,12 @@ class Lexer(object):
                         statetokens = tokendefs['root']
                         yield pos, tokens.Text, u'\n'
                         continue
+                    if hasmore:
+                        buf = self._decode(o.read(self.bufsize))
+                        hasmore = len(buf) == self.bufsize
+                        text = text[pos:] + buf
+                        pos = 0
+                        continue
                     yield pos, tokens.Error, text[pos]
                     pos += 1
                 except IndexError:
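Note: with these two hunks the tokenizer accepts file-like objects. It reads bufsize bytes at a time; a match that ends exactly at the end of the current buffer is skipped because the token might continue in the next chunk, and when no pattern consumes the remainder the unlexed tail is prepended to the next read before matching resumes at position 0. A sketch of driving the lexer from a stream with a deliberately small buffer, mirroring the tests added below:

    from cStringIO import StringIO
    from sqlparse import lexer

    lex = lexer.Lexer()
    lex.bufsize = 4  # force several buffer refills for this short input
    for ttype, value in lex.get_tokens(StringIO("SELECT 1; SELECT 2;")):
        print ttype, repr(value)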
diff --git a/tests/test_split.py b/tests/test_split.py
index 29ee72e..c73d9d4 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -116,3 +116,18 @@ class SQLSplitTest(TestCaseBase):
                'SELECT t FROM two')
         stmts = sqlparse.split(sql)
         self.assertEqual(len(stmts), 2)
+
+    def test_split_stream(self):
+        import types
+        from cStringIO import StringIO
+
+        stream = StringIO("SELECT 1; SELECT 2;")
+        stmts = sqlparse.splitstream(stream)
+        self.assertEqual(type(stmts), types.GeneratorType)
+        self.assertEqual(len(list(stmts)), 2)
+
+    def test_encoding_splitstream(self):
+        from cStringIO import StringIO
+        stream = StringIO("SELECT 1; SELECT 2;")
+        stmts = list(sqlparse.splitstream(stream))
+        self.assertEqual(type(stmts[0].tokens[0].value), unicode)
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 3f51a46..5b403f9 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -69,6 +69,13 @@ class TestTokenize(unittest.TestCase):
         self.assertEqual(tokens[2][0], Number.Integer)
         self.assertEqual(tokens[2][1], '-1')
 
+    def test_tab_expansion(self):
+        sql = "\t"
+        lex = lexer.Lexer()
+        lex.tabsize = 5
+        tokens = list(lex.get_tokens(sql))
+        self.assertEqual(tokens[0][1], " " * 5)
+
 class TestToken(unittest.TestCase):
 
     def test_str(self):
@@ -116,3 +123,35 @@ class TestTokenList(unittest.TestCase):
                          t2)
         self.assertEqual(x.token_matching(1, [lambda t: t.ttype is Keyword]),
                          None)
+
+class TestStream(unittest.TestCase):
+    def test_simple(self):
+        import types
+        from cStringIO import StringIO
+
+        stream = StringIO("SELECT 1; SELECT 2;")
+        lex = lexer.Lexer()
+
+        tokens = lex.get_tokens(stream)
+        self.assertEqual(len(list(tokens)), 9)
+
+        stream.seek(0)
+        lex.bufsize = 4
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 9)
+
+        stream.seek(0)
+        lex.bufsize = len(stream.getvalue())
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 9)
+
+    def test_error(self):
+        from cStringIO import StringIO
+
+        stream = StringIO("FOOBAR{")
+
+        lex = lexer.Lexer()
+        lex.bufsize = 4
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 2)
+        self.assertEqual(tokens[1][0], Error)