-rw-r--r-- | sqlparse/__init__.py   |  8
-rw-r--r-- | sqlparse/lexer.py      | 57
-rw-r--r-- | tests/test_split.py    | 15
-rw-r--r-- | tests/test_tokenize.py | 39
4 files changed, 101 insertions, 18 deletions
diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py
index f924c04..bacaa78 100644
--- a/sqlparse/__init__.py
+++ b/sqlparse/__init__.py
@@ -54,6 +54,14 @@ def split(sql):
     stack.split_statements = True
     return [unicode(stmt) for stmt in stack.run(sql)]
 
+def splitstream(sql):
+    """Split *sql* into single statements.
+
+    Returns a generator of statements.
+    """
+    stack = engine.FilterStack()
+    stack.split_statements = True
+    return stack.run(sql)
 
 from sqlparse.engine.filter import StatementFilter
 def split2(stream):
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index 321669d..67dbc29 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -159,6 +159,7 @@ class Lexer(object):
     stripnl = False
     tabsize = 0
     flags = re.IGNORECASE
+    bufsize = 4096
 
     tokens = {
         'root': [
@@ -214,6 +215,21 @@ class Lexer(object):
             filter_ = filter_(**options)
         self.filters.append(filter_)
 
+    def _decode(self, text):
+        if self.encoding == 'guess':
+            try:
+                text = text.decode('utf-8')
+                if text.startswith(u'\ufeff'):
+                    text = text[len(u'\ufeff'):]
+            except UnicodeDecodeError:
+                text = text.decode('latin1')
+        else:
+            text = text.decode(self.encoding)
+
+        if self.tabsize > 0:
+            text = text.expandtabs(self.tabsize)
+        return text
+
     def get_tokens(self, text, unfiltered=False):
         """
         Return an iterable of (tokentype, value) pairs generated from
@@ -223,24 +239,14 @@ class Lexer(object):
         Also preprocess the text, i.e. expand tabs and strip it
         if wanted and applies registered filters.
         """
-        if not isinstance(text, unicode):
-            if self.encoding == 'guess':
-                try:
-                    text = text.decode('utf-8')
-                    if text.startswith(u'\ufeff'):
-                        text = text[len(u'\ufeff'):]
-                except UnicodeDecodeError:
-                    text = text.decode('latin1')
-            else:
-                text = text.decode(self.encoding)
-        if self.stripall:
-            text = text.strip()
-        elif self.stripnl:
-            text = text.strip('\n')
-        if self.tabsize > 0:
-            text = text.expandtabs(self.tabsize)
-#        if not text.endswith('\n'):
-#            text += '\n'
+        if isinstance(text, str):
+            text = self._decode(text)
+
+        if isinstance(text, basestring):
+            if self.stripall:
+                text = text.strip()
+            elif self.stripnl:
+                text = text.strip('\n')
 
         def streamer():
             for i, t, v in self.get_tokens_unprocessed(text):
@@ -261,10 +267,19 @@ class Lexer(object):
         statestack = list(stack)
         statetokens = tokendefs[statestack[-1]]
         known_names = {}
+
+        hasmore = False
+        if hasattr(text, 'read'):
+            o, text = text, self._decode(text.read(self.bufsize))
+            hasmore = len(text) == self.bufsize
+
         while 1:
             for rexmatch, action, new_state in statetokens:
                 m = rexmatch(text, pos)
                 if m:
+                    if hasmore and m.end() == len(text):
+                        continue
+
                     # print rex.pattern
                     value = m.group()
                     if value in known_names:
@@ -307,6 +322,12 @@ class Lexer(object):
                         statetokens = tokendefs['root']
                         yield pos, tokens.Text, u'\n'
                         continue
+                    if hasmore:
+                        buf = self._decode(o.read(self.bufsize))
+                        hasmore = len(buf) == self.bufsize
+                        text = text[pos:] + buf
+                        pos = 0
+                        continue
                     yield pos, tokens.Error, text[pos]
                     pos += 1
                 except IndexError:
diff --git a/tests/test_split.py b/tests/test_split.py
index 29ee72e..c73d9d4 100644
--- a/tests/test_split.py
+++ b/tests/test_split.py
@@ -116,3 +116,18 @@ class SQLSplitTest(TestCaseBase):
                'SELECT t FROM two')
         stmts = sqlparse.split(sql)
         self.assertEqual(len(stmts), 2)
+
+    def test_split_stream(self):
+        import types
+        from cStringIO import StringIO
+
+        stream = StringIO("SELECT 1; SELECT 2;")
+        stmts = sqlparse.splitstream(stream)
+        self.assertEqual(type(stmts), types.GeneratorType)
+        self.assertEqual(len(list(stmts)), 2)
+
+    def test_encoding_splitstream(self):
+        from cStringIO import StringIO
+        stream = StringIO("SELECT 1; SELECT 2;")
+        stmts = list(sqlparse.splitstream(stream))
+        self.assertEqual(type(stmts[0].tokens[0].value), unicode)
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 3f51a46..5b403f9 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -69,6 +69,13 @@ class TestTokenize(unittest.TestCase):
         self.assertEqual(tokens[2][0], Number.Integer)
         self.assertEqual(tokens[2][1], '-1')
 
+    def test_tab_expansion(self):
+        sql = "\t"
+        lex = lexer.Lexer()
+        lex.tabsize = 5
+        tokens = list(lex.get_tokens(sql))
+        self.assertEqual(tokens[0][1], " " * 5)
+
 
 class TestToken(unittest.TestCase):
     def test_str(self):
@@ -116,3 +123,35 @@ class TestTokenList(unittest.TestCase):
                          t2)
         self.assertEqual(x.token_matching(1, [lambda t: t.ttype is Keyword]),
                          None)
+
+class TestStream(unittest.TestCase):
+    def test_simple(self):
+        import types
+        from cStringIO import StringIO
+
+        stream = StringIO("SELECT 1; SELECT 2;")
+        lex = lexer.Lexer()
+
+        tokens = lex.get_tokens(stream)
+        self.assertEqual(len(list(tokens)), 9)
+
+        stream.seek(0)
+        lex.bufsize = 4
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 9)
+
+        stream.seek(0)
+        lex.bufsize = len(stream.getvalue())
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 9)
+
+    def test_error(self):
+        from cStringIO import StringIO
+
+        stream = StringIO("FOOBAR{")
+
+        lex = lexer.Lexer()
+        lex.bufsize = 4
+        tokens = list(lex.get_tokens(stream))
+        self.assertEqual(len(tokens), 2)
+        self.assertEqual(tokens[1][0], Error)
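For reference, a minimal usage sketch of the new splitstream() API introduced
by this patch (Python 2, matching the patch's cStringIO/unicode idioms; the
SQL text is illustrative):

    from cStringIO import StringIO

    import sqlparse

    stream = StringIO("SELECT * FROM foo; UPDATE foo SET bar = 1;")

    # splitstream() returns a lazy generator rather than a list, so
    # statements become available as the underlying stream is consumed.
    for stmt in sqlparse.splitstream(stream):
        print unicode(stmt).strip()

Unlike split(), which builds a list of strings up front, the result here is
only materialized on iteration.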

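The heart of the lexer change is the refill-and-retry loop: input is consumed
in bufsize chunks, and a match that runs flush against the end of the current
buffer is not accepted until the next chunk has been appended, so a token
straddling a chunk boundary is never emitted in pieces. A standalone sketch of
that technique (illustrative only; chunked_tokens and its token pattern are
invented for this example and are not part of sqlparse):

    import re
    from cStringIO import StringIO

    TOKEN = re.compile(r'\w+|\s+|.', re.DOTALL)

    def chunked_tokens(stream, bufsize=4):
        # Read the first chunk; a short read means EOF was already reached.
        text = stream.read(bufsize)
        hasmore = len(text) == bufsize
        pos = 0
        while pos < len(text):
            m = TOKEN.match(text, pos)
            # A match ending exactly at the buffer edge may be truncated:
            # refill the buffer and retry before accepting it.
            if hasmore and (m is None or m.end() == len(text)):
                buf = stream.read(bufsize)
                hasmore = len(buf) == bufsize
                text = text[pos:] + buf
                pos = 0
                continue
            if m is None:
                break  # nothing matches; the real lexer yields an Error token
            yield m.group()
            pos = m.end()

    print list(chunked_tokens(StringIO("SELECT 1; SELECT 2;")))
    # -> ['SELECT', ' ', '1', ';', ' ', 'SELECT', ' ', '2', ';']
    # 9 tokens, the same count TestStream.test_simple asserts above.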