author     Vik <vmuriart@gmail.com>        2016-06-16 02:33:28 -0700
committer  GitHub <noreply@github.com>     2016-06-16 02:33:28 -0700
commit     92b5f2bb88ed1c1080ecf7eb7449f5c642ae196a (patch)
tree       30b53c5970fc01fab5a14c9a0298f4e8d4eba077
parent     451d6d5d380cb4246c47e374aa9c4034fc7f9805 (diff)
parent     9fcf1f2cda629cdf11a8a4ac596fb7cae0e89de9 (diff)
download   sqlparse-92b5f2bb88ed1c1080ecf7eb7449f5c642ae196a.tar.gz
Merge pull request #260 from vmuriart/long_live_indexes
Long live indexes - Improve performance
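[Editor's note] The heart of this PR, visible throughout the diff below, is a switch from passing token objects around to passing integer indexes: TokenList.token_next, token_prev, and token_next_by now take an index and return an (index, token) pair, so callers no longer force repeated O(n) token_index() scans. A minimal sketch of the new calling pattern, assuming the post-merge API shown in this diff (the old style is kept in comments for contrast):

```python
import sqlparse

parsed = sqlparse.parse('select foo, bar from baz')[0]

# Old style: token_next() took and returned bare token objects, and the
# library had to call token_index() internally to find its position.
#     token = parsed.token_next(0)          # first child token
#     token = parsed.token_next(token)      # next one, O(n) lookup inside
#
# New style: walk by integer index; each call returns an (idx, token) pair.
tidx, token = parsed.token_next(-1)         # idx=-1 yields the first token
while token is not None:
    print(tidx, token.ttype, repr(token.value))
    tidx, token = parsed.token_next(tidx)   # continues directly from tidx
```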
-rw-r--r--  AUTHORS                              |   1
-rw-r--r--  CHANGELOG                            |   1
-rw-r--r--  examples/column_defs_lowlevel.py     |   8
-rw-r--r--  sqlparse/engine/grouping.py          | 400
-rw-r--r--  sqlparse/filters/aligned_indent.py   |  29
-rw-r--r--  sqlparse/filters/others.py           |  34
-rw-r--r--  sqlparse/filters/reindent.py         |  61
-rw-r--r--  sqlparse/sql.py                      | 162
-rw-r--r--  sqlparse/utils.py                    |  12
-rw-r--r--  tests/test_grouping.py               |   6
-rw-r--r--  tests/test_regressions.py            |  10
-rw-r--r--  tests/test_tokenize.py               |  18
12 files changed, 424 insertions, 318 deletions
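[Editor's note] Most of the churn sits in sqlparse/engine/grouping.py, where the recursive find_matching() helper is dropped in favor of a single left-to-right pass in _group_matching(): opener indexes are pushed onto an `opens` stack, and each closer pops one and groups the span, as the hunks below show. A standalone sketch of that one-pass idea, using raw characters instead of the M_OPEN/M_CLOSE token patterns the real code matches:

```python
def matching_spans(tokens, open_tok='(', close_tok=')'):
    """Return (open_idx, close_idx) pairs for balanced delimiters in one pass.

    The new _group_matching() in the diff uses the same stack of indexes;
    it additionally maintains a tidx_offset, because each group_tokens()
    call collapses a span into one token and shifts every later index.
    """
    opens, spans = [], []
    for idx, tok in enumerate(tokens):
        if tok == open_tok:
            opens.append(idx)           # remember where this group opened
        elif tok == close_tok:
            if not opens:
                continue                # unbalanced SQL: skip the stray closer
            spans.append((opens.pop(), idx))
    return spans

print(matching_spans(list('a(b(c)d)e')))  # -> [(3, 5), (1, 7)]
```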
@@ -34,6 +34,7 @@ Alphabetical list of contributors: * Ryan Wooden <rygwdn@gmail.com> * saaj <id@saaj.me> * Shen Longxing <shenlongxing2012@gmail.com> +* Sjoerd Job Postmus * spigwitmer <itgpmc@gmail.com> * Tenghuan <tenghuanhe@gmail.com> * Tim Graham <timograham@gmail.com> @@ -10,6 +10,7 @@ Internal Changes sqlparse.exceptions. * sqlparse.sql.Token.to_unicode was removed. * Lots of code cleanups and modernization (thanks esp. to vmuriart!). +* Improved grouping performance. (sjoerdjob) Enhancements diff --git a/examples/column_defs_lowlevel.py b/examples/column_defs_lowlevel.py index 5e98be3..584b3f3 100644 --- a/examples/column_defs_lowlevel.py +++ b/examples/column_defs_lowlevel.py @@ -17,16 +17,16 @@ def extract_definitions(token_list): definitions = [] tmp = [] # grab the first token, ignoring whitespace. idx=1 to skip open ( - token = token_list.token_next(1) + tidx, token = token_list.token_next(1) while token and not token.match(sqlparse.tokens.Punctuation, ')'): tmp.append(token) # grab the next token, this times including whitespace - token = token_list.token_next(token, skip_ws=False) + tidx, token = token_list.token_next(tidx, skip_ws=False) # split on ",", except when on end of statement if token and token.match(sqlparse.tokens.Punctuation, ','): definitions.append(tmp) tmp = [] - token = token_list.token_next(token) + tidx, token = token_list.token_next(tidx) if tmp and isinstance(tmp[0], sqlparse.sql.Identifier): definitions.append(tmp) return definitions @@ -41,7 +41,7 @@ if __name__ == '__main__': parsed = sqlparse.parse(SQL)[0] # extract the parenthesis which holds column definitions - par = parsed.token_next_by(i=sqlparse.sql.Parenthesis) + _, par = parsed.token_next_by(i=sqlparse.sql.Parenthesis) columns = extract_definitions(par) for column in columns: diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py index 6e414b8..62f37a6 100644 --- a/sqlparse/engine/grouping.py +++ b/sqlparse/engine/grouping.py @@ -7,56 +7,58 @@ from sqlparse import sql from sqlparse import tokens as T -from sqlparse.utils import recurse, imt, find_matching - -M_ROLE = (T.Keyword, ('null', 'role')) -M_SEMICOLON = (T.Punctuation, ';') -M_COMMA = (T.Punctuation, ',') +from sqlparse.utils import recurse, imt T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float) T_STRING = (T.String, T.String.Single, T.String.Symbol) T_NAME = (T.Name, T.Name.Placeholder) -def _group_left_right(tlist, m, cls, - valid_left=lambda t: t is not None, - valid_right=lambda t: t is not None, - semicolon=False): - """Groups together tokens that are joined by a middle token. ie. x < y""" +def _group_matching(tlist, cls): + """Groups Tokens that have beginning and end.""" + opens = [] + tidx_offset = 0 + for idx, token in enumerate(list(tlist)): + tidx = idx - tidx_offset + + if token.is_whitespace(): + # ~50% of tokens will be whitespace. Will checking early + # for them avoid 3 comparisons, but then add 1 more comparison + # for the other ~50% of tokens... + continue - for token in list(tlist): if token.is_group() and not isinstance(token, cls): - _group_left_right(token, m, cls, valid_left, valid_right, - semicolon) - - if not token.match(*m): + # Check inside previously grouped (ie. parenthesis) if group + # of differnt type is inside (ie, case). 
though ideally should + # should check for all open/close tokens at once to avoid recursion + _group_matching(token, cls) continue - left, right = tlist.token_prev(token), tlist.token_next(token) + if token.match(*cls.M_OPEN): + opens.append(tidx) - if valid_left(left) and valid_right(right): - if semicolon: - # only overwrite if a semicolon present. - sright = tlist.token_next_by(m=M_SEMICOLON, idx=right) - right = sright or right - tokens = tlist.tokens_between(left, right) - tlist.group_tokens(cls, tokens, extend=True) + elif token.match(*cls.M_CLOSE): + try: + open_idx = opens.pop() + except IndexError: + # this indicates invalid sql and unbalanced tokens. + # instead of break, continue in case other "valid" groups exist + continue + close_idx = tidx + tlist.group_tokens(cls, open_idx, close_idx) + tidx_offset += close_idx - open_idx -def _group_matching(tlist, cls): - """Groups Tokens that have beginning and end.""" - [_group_matching(sgroup, cls) for sgroup in tlist.get_sublists() - if not isinstance(sgroup, cls)] - idx = 1 if isinstance(tlist, cls) else 0 +def group_brackets(tlist): + _group_matching(tlist, sql.SquareBrackets) - token = tlist.token_next_by(m=cls.M_OPEN, idx=idx) - while token: - end = find_matching(tlist, token, cls.M_OPEN, cls.M_CLOSE) - if end is not None: - tokens = tlist.tokens_between(token, end) - token = tlist.group_tokens(cls, tokens) - _group_matching(token, cls) - token = tlist.token_next_by(m=cls.M_OPEN, idx=token) + +def group_parenthesis(tlist): + _group_matching(tlist, sql.Parenthesis) + + +def group_case(tlist): + _group_matching(tlist, sql.Case) def group_if(tlist): @@ -67,149 +69,202 @@ def group_for(tlist): _group_matching(tlist, sql.For) -def group_foreach(tlist): - _group_matching(tlist, sql.For) - - def group_begin(tlist): _group_matching(tlist, sql.Begin) +def group_typecasts(tlist): + def match(token): + return token.match(T.Punctuation, '::') + + def valid(token): + return token is not None + + def post(tlist, pidx, tidx, nidx): + return pidx, nidx + + valid_prev = valid_next = valid + _group(tlist, sql.Identifier, match, valid_prev, valid_next, post) + + +def group_period(tlist): + def match(token): + return token.match(T.Punctuation, '.') + + def valid_prev(token): + sqlcls = sql.SquareBrackets, sql.Identifier + ttypes = T.Name, T.String.Symbol + return imt(token, i=sqlcls, t=ttypes) + + def valid_next(token): + sqlcls = sql.SquareBrackets, sql.Function + ttypes = T.Name, T.String.Symbol, T.Wildcard + return imt(token, i=sqlcls, t=ttypes) + + def post(tlist, pidx, tidx, nidx): + return pidx, nidx + + _group(tlist, sql.Identifier, match, valid_prev, valid_next, post) + + def group_as(tlist): - lfunc = lambda tk: not imt(tk, t=T.Keyword) or tk.value == 'NULL' - rfunc = lambda tk: not imt(tk, t=(T.DML, T.DDL)) - _group_left_right(tlist, (T.Keyword, 'AS'), sql.Identifier, - valid_left=lfunc, valid_right=rfunc) + def match(token): + return token.is_keyword and token.normalized == 'AS' + + def valid_prev(token): + return token.normalized == 'NULL' or not token.is_keyword + + def valid_next(token): + ttypes = T.DML, T.DDL + return not imt(token, t=ttypes) + + def post(tlist, pidx, tidx, nidx): + return pidx, nidx + + _group(tlist, sql.Identifier, match, valid_prev, valid_next, post) def group_assignment(tlist): - _group_left_right(tlist, (T.Assignment, ':='), sql.Assignment, - semicolon=True) + def match(token): + return token.match(T.Assignment, ':=') + def valid(token): + return token is not None -def group_comparison(tlist): - I_COMPERABLE = 
(sql.Parenthesis, sql.Function, sql.Identifier, - sql.Operation) - T_COMPERABLE = T_NUMERICAL + T_STRING + T_NAME + def post(tlist, pidx, tidx, nidx): + m_semicolon = T.Punctuation, ';' + snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx) + nidx = snidx or nidx + return pidx, nidx - func = lambda tk: (imt(tk, t=T_COMPERABLE, i=I_COMPERABLE) or - (tk and tk.is_keyword and tk.normalized == 'NULL')) + valid_prev = valid_next = valid + _group(tlist, sql.Assignment, match, valid_prev, valid_next, post) - _group_left_right(tlist, (T.Operator.Comparison, None), sql.Comparison, - valid_left=func, valid_right=func) +def group_comparison(tlist): + sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier, + sql.Operation) + ttypes = T_NUMERICAL + T_STRING + T_NAME + + def match(token): + return token.ttype == T.Operator.Comparison + + def valid(token): + if imt(token, t=ttypes, i=sqlcls): + return True + elif token and token.is_keyword and token.normalized == 'NULL': + return True + else: + return False -def group_case(tlist): - _group_matching(tlist, sql.Case) + def post(tlist, pidx, tidx, nidx): + return pidx, nidx + + valid_prev = valid_next = valid + _group(tlist, sql.Comparison, match, + valid_prev, valid_next, post, extend=False) @recurse(sql.Identifier) def group_identifier(tlist): - T_IDENT = (T.String.Symbol, T.Name) + ttypes = (T.String.Symbol, T.Name) - token = tlist.token_next_by(t=T_IDENT) + tidx, token = tlist.token_next_by(t=ttypes) while token: - token = tlist.group_tokens(sql.Identifier, [token, ]) - token = tlist.token_next_by(t=T_IDENT, idx=token) + tlist.group_tokens(sql.Identifier, tidx, tidx) + tidx, token = tlist.token_next_by(t=ttypes, idx=tidx) -def group_period(tlist): - lfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Identifier), - t=(T.Name, T.String.Symbol,)) +def group_arrays(tlist): + sqlcls = sql.SquareBrackets, sql.Identifier, sql.Function + ttypes = T.Name, T.String.Symbol - rfunc = lambda tk: imt(tk, i=(sql.SquareBrackets, sql.Function), - t=(T.Name, T.String.Symbol, T.Wildcard)) + def match(token): + return isinstance(token, sql.SquareBrackets) - _group_left_right(tlist, (T.Punctuation, '.'), sql.Identifier, - valid_left=lfunc, valid_right=rfunc) + def valid_prev(token): + return imt(token, i=sqlcls, t=ttypes) + def valid_next(token): + return True -def group_arrays(tlist): - token = tlist.token_next_by(i=sql.SquareBrackets) - while token: - prev = tlist.token_prev(token) - if imt(prev, i=(sql.SquareBrackets, sql.Identifier, sql.Function), - t=(T.Name, T.String.Symbol,)): - tokens = tlist.tokens_between(prev, token) - token = tlist.group_tokens(sql.Identifier, tokens, extend=True) - token = tlist.token_next_by(i=sql.SquareBrackets, idx=token) + def post(tlist, pidx, tidx, nidx): + return pidx, tidx + _group(tlist, sql.Identifier, match, + valid_prev, valid_next, post, extend=True, recurse=False) -@recurse(sql.Identifier) -def group_operator(tlist): - I_CYCLE = (sql.SquareBrackets, sql.Parenthesis, sql.Function, - sql.Identifier, sql.Operation) - # wilcards wouldn't have operations next to them - T_CYCLE = T_NUMERICAL + T_STRING + T_NAME - func = lambda tk: imt(tk, i=I_CYCLE, t=T_CYCLE) - token = tlist.token_next_by(t=(T.Operator, T.Wildcard)) - while token: - left, right = tlist.token_prev(token), tlist.token_next(token) - - if func(left) and func(right): - token.ttype = T.Operator - tokens = tlist.tokens_between(left, right) - token = tlist.group_tokens(sql.Operation, tokens) +def group_operator(tlist): + ttypes = T_NUMERICAL + T_STRING + T_NAME + sqlcls 
= (sql.SquareBrackets, sql.Parenthesis, sql.Function, + sql.Identifier, sql.Operation) - token = tlist.token_next_by(t=(T.Operator, T.Wildcard), idx=token) + def match(token): + return imt(token, t=(T.Operator, T.Wildcard)) + def valid(token): + return imt(token, i=sqlcls, t=ttypes) -@recurse(sql.IdentifierList) -def group_identifier_list(tlist): - I_IDENT_LIST = (sql.Function, sql.Case, sql.Identifier, sql.Comparison, - sql.IdentifierList, sql.Operation) - T_IDENT_LIST = (T_NUMERICAL + T_STRING + T_NAME + - (T.Keyword, T.Comment, T.Wildcard)) + def post(tlist, pidx, tidx, nidx): + tlist[tidx].ttype = T.Operator + return pidx, nidx - func = lambda t: imt(t, i=I_IDENT_LIST, m=M_ROLE, t=T_IDENT_LIST) - token = tlist.token_next_by(m=M_COMMA) + valid_prev = valid_next = valid + _group(tlist, sql.Operation, match, + valid_prev, valid_next, post, extend=False) - while token: - before, after = tlist.token_prev(token), tlist.token_next(token) - if func(before) and func(after): - tokens = tlist.tokens_between(before, after) - token = tlist.group_tokens(sql.IdentifierList, tokens, extend=True) - token = tlist.token_next_by(m=M_COMMA, idx=token) +def group_identifier_list(tlist): + m_role = T.Keyword, ('null', 'role') + m_comma = T.Punctuation, ',' + sqlcls = (sql.Function, sql.Case, sql.Identifier, sql.Comparison, + sql.IdentifierList, sql.Operation) + ttypes = (T_NUMERICAL + T_STRING + T_NAME + + (T.Keyword, T.Comment, T.Wildcard)) + def match(token): + return imt(token, m=m_comma) -def group_brackets(tlist): - _group_matching(tlist, sql.SquareBrackets) + def valid(token): + return imt(token, i=sqlcls, m=m_role, t=ttypes) + def post(tlist, pidx, tidx, nidx): + return pidx, nidx -def group_parenthesis(tlist): - _group_matching(tlist, sql.Parenthesis) + valid_prev = valid_next = valid + _group(tlist, sql.IdentifierList, match, + valid_prev, valid_next, post, extend=True) @recurse(sql.Comment) def group_comments(tlist): - token = tlist.token_next_by(t=T.Comment) + tidx, token = tlist.token_next_by(t=T.Comment) while token: - end = tlist.token_not_matching( - token, lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace()) + eidx, end = tlist.token_not_matching( + lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace(), idx=tidx) if end is not None: - end = tlist.token_prev(end, False) - tokens = tlist.tokens_between(token, end) - token = tlist.group_tokens(sql.Comment, tokens) + eidx, end = tlist.token_prev(eidx, skip_ws=False) + tlist.group_tokens(sql.Comment, tidx, eidx) - token = tlist.token_next_by(t=T.Comment, idx=token) + tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx) @recurse(sql.Where) def group_where(tlist): - token = tlist.token_next_by(m=sql.Where.M_OPEN) + tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN) while token: - end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=token) + eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx) if end is None: - tokens = tlist.tokens_between(token, tlist._groupable_tokens[-1]) + end = tlist._groupable_tokens[-1] else: - tokens = tlist.tokens_between( - token, tlist.tokens[tlist.token_index(end) - 1]) - - token = tlist.group_tokens(sql.Where, tokens) - token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=token) + end = tlist.tokens[eidx - 1] + # TODO: convert this to eidx instead of end token. 
+ # i think above values are len(tlist) and eidx-1 + eidx = tlist.token_index(end) + tlist.group_tokens(sql.Where, tidx, eidx) + tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx) @recurse() @@ -217,17 +272,12 @@ def group_aliased(tlist): I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier, sql.Operation) - token = tlist.token_next_by(i=I_ALIAS, t=T.Number) + tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number) while token: - next_ = tlist.token_next(token) + nidx, next_ = tlist.token_next(tidx) if imt(next_, i=sql.Identifier): - tokens = tlist.tokens_between(token, next_) - token = tlist.group_tokens(sql.Identifier, tokens, extend=True) - token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=token) - - -def group_typecasts(tlist): - _group_left_right(tlist, (T.Punctuation, '::'), sql.Identifier) + tlist.group_tokens(sql.Identifier, tidx, nidx, extend=True) + tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx) @recurse(sql.Function) @@ -241,45 +291,51 @@ def group_functions(tlist): has_table = True if has_create and has_table: return - token = tlist.token_next_by(t=T.Name) + + tidx, token = tlist.token_next_by(t=T.Name) while token: - next_ = tlist.token_next(token) - if imt(next_, i=sql.Parenthesis): - tokens = tlist.tokens_between(token, next_) - token = tlist.group_tokens(sql.Function, tokens) - token = tlist.token_next_by(t=T.Name, idx=token) + nidx, next_ = tlist.token_next(tidx) + if isinstance(next_, sql.Parenthesis): + tlist.group_tokens(sql.Function, tidx, nidx) + tidx, token = tlist.token_next_by(t=T.Name, idx=tidx) def group_order(tlist): """Group together Identifier and Asc/Desc token""" - token = tlist.token_next_by(t=T.Keyword.Order) + tidx, token = tlist.token_next_by(t=T.Keyword.Order) while token: - prev = tlist.token_prev(token) - if imt(prev, i=sql.Identifier, t=T.Number): - tokens = tlist.tokens_between(prev, token) - token = tlist.group_tokens(sql.Identifier, tokens) - token = tlist.token_next_by(t=T.Keyword.Order, idx=token) + pidx, prev_ = tlist.token_prev(tidx) + if imt(prev_, i=sql.Identifier, t=T.Number): + tlist.group_tokens(sql.Identifier, pidx, tidx) + tidx = pidx + tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx) @recurse() def align_comments(tlist): - token = tlist.token_next_by(i=sql.Comment) + tidx, token = tlist.token_next_by(i=sql.Comment) while token: - before = tlist.token_prev(token) - if isinstance(before, sql.TokenList): - tokens = tlist.tokens_between(before, token) - token = tlist.group_tokens(sql.TokenList, tokens, extend=True) - token = tlist.token_next_by(i=sql.Comment, idx=token) + pidx, prev_ = tlist.token_prev(tidx) + if isinstance(prev_, sql.TokenList): + tlist.group_tokens(sql.TokenList, pidx, tidx, extend=True) + tidx = pidx + tidx, token = tlist.token_next_by(i=sql.Comment, idx=tidx) def group(stmt): for func in [ group_comments, + + # _group_matching group_brackets, group_parenthesis, + group_case, + group_if, + group_for, + group_begin, + group_functions, group_where, - group_case, group_period, group_arrays, group_identifier, @@ -290,12 +346,42 @@ def group(stmt): group_aliased, group_assignment, group_comparison, + align_comments, group_identifier_list, - group_if, - group_for, - group_foreach, - group_begin, ]: func(stmt) return stmt + + +def _group(tlist, cls, match, + valid_prev=lambda t: True, + valid_next=lambda t: True, + post=None, + extend=True, + recurse=True + ): + """Groups together tokens that are joined by a middle token. ie. 
x < y""" + + tidx_offset = 0 + pidx, prev_ = None, None + for idx, token in enumerate(list(tlist)): + tidx = idx - tidx_offset + + if token.is_whitespace(): + continue + + if recurse and token.is_group() and not isinstance(token, cls): + _group(token, cls, match, valid_prev, valid_next, post, extend) + + if match(token): + nidx, next_ = tlist.token_next(tidx) + if valid_prev(prev_) and valid_next(next_): + from_idx, to_idx = post(tlist, pidx, tidx, nidx) + grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend) + + tidx_offset += to_idx - from_idx + pidx, prev_ = from_idx, grp + continue + + pidx, prev_ = tidx, token diff --git a/sqlparse/filters/aligned_indent.py b/sqlparse/filters/aligned_indent.py index ea749e9..2fea4d2 100644 --- a/sqlparse/filters/aligned_indent.py +++ b/sqlparse/filters/aligned_indent.py @@ -46,7 +46,8 @@ class AlignedIndentFilter(object): def _process_parenthesis(self, tlist): # if this isn't a subquery, don't re-indent - if tlist.token_next_by(m=(T.DML, 'SELECT')): + _, token = tlist.token_next_by(m=(T.DML, 'SELECT')) + if token is not None: with indent(self): tlist.insert_after(tlist[0], self.nl('SELECT')) # process the inside of the parantheses @@ -66,7 +67,7 @@ class AlignedIndentFilter(object): offset_ = len('case ') + len('when ') cases = tlist.get_cases(skip_ws=True) # align the end as well - end_token = tlist.token_next_by(m=(T.Keyword, 'END')) + _, end_token = tlist.token_next_by(m=(T.Keyword, 'END')) cases.append((None, [end_token])) condition_width = [len(' '.join(map(text_type, cond))) if cond else 0 @@ -85,18 +86,18 @@ class AlignedIndentFilter(object): max_cond_width - condition_width[i])) tlist.insert_after(cond[-1], ws) - def _next_token(self, tlist, idx=0): + def _next_token(self, tlist, idx=-1): split_words = T.Keyword, self.split_words, True - token = tlist.token_next_by(m=split_words, idx=idx) + tidx, token = tlist.token_next_by(m=split_words, idx=idx) # treat "BETWEEN x and y" as a single statement - if token and token.value.upper() == 'BETWEEN': - token = self._next_token(tlist, token) - if token and token.value.upper() == 'AND': - token = self._next_token(tlist, token) - return token + if token and token.normalized == 'BETWEEN': + tidx, token = self._next_token(tlist, tidx) + if token and token.normalized == 'AND': + tidx, token = self._next_token(tlist, tidx) + return tidx, token def _split_kwds(self, tlist): - token = self._next_token(tlist) + tidx, token = self._next_token(tlist) while token: # joins are special case. only consider the first word as aligner if token.match(T.Keyword, self.join_words, regex=True): @@ -104,15 +105,17 @@ class AlignedIndentFilter(object): else: token_indent = text_type(token) tlist.insert_before(token, self.nl(token_indent)) - token = self._next_token(tlist, token) + tidx += 1 + tidx, token = self._next_token(tlist, tidx) def _process_default(self, tlist): self._split_kwds(tlist) # process any sub-sub statements for sgroup in tlist.get_sublists(): - prev = tlist.token_prev(sgroup) + idx = tlist.token_index(sgroup) + pidx, prev_ = tlist.token_prev(idx) # HACK: make "group/order by" work. Longer than max_len. 
- offset_ = 3 if (prev and prev.match(T.Keyword, 'BY')) else 0 + offset_ = 3 if (prev_ and prev_.match(T.Keyword, 'BY')) else 0 with offset(self, offset_): self._process(sgroup) diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py index 71b1f8e..9d4a1d1 100644 --- a/sqlparse/filters/others.py +++ b/sqlparse/filters/others.py @@ -16,21 +16,20 @@ class StripCommentsFilter(object): # TODO(andi) Comment types should be unified, see related issue38 return tlist.token_next_by(i=sql.Comment, t=T.Comment) - token = get_next_comment() + tidx, token = get_next_comment() while token: - prev = tlist.token_prev(token, skip_ws=False) - next_ = tlist.token_next(token, skip_ws=False) + pidx, prev_ = tlist.token_prev(tidx, skip_ws=False) + nidx, next_ = tlist.token_next(tidx, skip_ws=False) # Replace by whitespace if prev and next exist and if they're not # whitespaces. This doesn't apply if prev or next is a paranthesis. - if (prev is None or next_ is None or - prev.is_whitespace() or prev.match(T.Punctuation, '(') or + if (prev_ is None or next_ is None or + prev_.is_whitespace() or prev_.match(T.Punctuation, '(') or next_.is_whitespace() or next_.match(T.Punctuation, ')')): tlist.tokens.remove(token) else: - tidx = tlist.token_index(token) tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ') - token = get_next_comment() + tidx, token = get_next_comment() def process(self, stmt): [self.process(sgroup) for sgroup in stmt.get_sublists()] @@ -86,20 +85,21 @@ class StripWhitespaceFilter(object): class SpacesAroundOperatorsFilter(object): @staticmethod def _process(tlist): - def next_token(idx=0): - return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx) - token = next_token() + ttypes = (T.Operator, T.Comparison) + tidx, token = tlist.token_next_by(t=ttypes) while token: - prev_ = tlist.token_prev(token, skip_ws=False) - if prev_ and prev_.ttype != T.Whitespace: - tlist.insert_before(token, sql.Token(T.Whitespace, ' ')) - - next_ = tlist.token_next(token, skip_ws=False) + nidx, next_ = tlist.token_next(tidx, skip_ws=False) if next_ and next_.ttype != T.Whitespace: - tlist.insert_after(token, sql.Token(T.Whitespace, ' ')) + tlist.insert_after(tidx, sql.Token(T.Whitespace, ' ')) + + pidx, prev_ = tlist.token_prev(tidx, skip_ws=False) + if prev_ and prev_.ttype != T.Whitespace: + tlist.insert_before(tidx, sql.Token(T.Whitespace, ' ')) + tidx += 1 # has to shift since token inserted before it - token = next_token(idx=token) + # assert tlist.token_index(token) == tidx + tidx, token = tlist.token_next_by(t=ttypes, idx=tidx) def process(self, stmt): [self.process(sgroup) for sgroup in stmt.get_sublists()] diff --git a/sqlparse/filters/reindent.py b/sqlparse/filters/reindent.py index b490631..68595a5 100644 --- a/sqlparse/filters/reindent.py +++ b/sqlparse/filters/reindent.py @@ -44,44 +44,50 @@ class ReindentFilter(object): def nl(self): return sql.Token(T.Whitespace, self.n + self.char * self.leading_ws) - def _next_token(self, tlist, idx=0): + def _next_token(self, tlist, idx=-1): split_words = ('FROM', 'STRAIGHT_JOIN$', 'JOIN$', 'AND', 'OR', 'GROUP', 'ORDER', 'UNION', 'VALUES', 'SET', 'BETWEEN', 'EXCEPT', 'HAVING') - token = tlist.token_next_by(m=(T.Keyword, split_words, True), idx=idx) + m_split = T.Keyword, split_words, True + tidx, token = tlist.token_next_by(m=m_split, idx=idx) - if token and token.value.upper() == 'BETWEEN': - token = self._next_token(tlist, token) + if token and token.normalized == 'BETWEEN': + tidx, token = self._next_token(tlist, tidx) - if token and 
token.value.upper() == 'AND': - token = self._next_token(tlist, token) + if token and token.normalized == 'AND': + tidx, token = self._next_token(tlist, tidx) - return token + return tidx, token def _split_kwds(self, tlist): - token = self._next_token(tlist) + tidx, token = self._next_token(tlist) while token: - prev = tlist.token_prev(token, skip_ws=False) - uprev = text_type(prev) + pidx, prev_ = tlist.token_prev(tidx, skip_ws=False) + uprev = text_type(prev_) - if prev and prev.is_whitespace(): - tlist.tokens.remove(prev) + if prev_ and prev_.is_whitespace(): + del tlist.tokens[pidx] + tidx -= 1 if not (uprev.endswith('\n') or uprev.endswith('\r')): - tlist.insert_before(token, self.nl()) + tlist.insert_before(tidx, self.nl()) + tidx += 1 - token = self._next_token(tlist, token) + tidx, token = self._next_token(tlist, tidx) def _split_statements(self, tlist): - token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML)) + ttypes = T.Keyword.DML, T.Keyword.DDL + tidx, token = tlist.token_next_by(t=ttypes) while token: - prev = tlist.token_prev(token, skip_ws=False) - if prev and prev.is_whitespace(): - tlist.tokens.remove(prev) + pidx, prev_ = tlist.token_prev(tidx, skip_ws=False) + if prev_ and prev_.is_whitespace(): + del tlist.tokens[pidx] + tidx -= 1 # only break if it's not the first token - tlist.insert_before(token, self.nl()) if prev else None - token = tlist.token_next_by(t=(T.Keyword.DDL, T.Keyword.DML), - idx=token) + if prev_: + tlist.insert_before(tidx, self.nl()) + tidx += 1 + tidx, token = tlist.token_next_by(t=ttypes, idx=tidx) def _process(self, tlist): func_name = '_process_{cls}'.format(cls=type(tlist).__name__) @@ -89,16 +95,17 @@ class ReindentFilter(object): func(tlist) def _process_where(self, tlist): - token = tlist.token_next_by(m=(T.Keyword, 'WHERE')) + tidx, token = tlist.token_next_by(m=(T.Keyword, 'WHERE')) # issue121, errors in statement fixed?? 
- tlist.insert_before(token, self.nl()) + tlist.insert_before(tidx, self.nl()) with indent(self): self._process_default(tlist) def _process_parenthesis(self, tlist): - is_dml_dll = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL)) - first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN) + ttypes = T.Keyword.DML, T.Keyword.DDL + _, is_dml_dll = tlist.token_next_by(t=ttypes) + fidx, first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN) with indent(self, 1 if is_dml_dll else 0): tlist.tokens.insert(0, self.nl()) if is_dml_dll else None @@ -135,8 +142,8 @@ class ReindentFilter(object): # len "when ", "then ", "else " with offset(self, len("WHEN ")): self._process_default(tlist) - end = tlist.token_next_by(m=sql.Case.M_CLOSE) - tlist.insert_before(end, self.nl()) + end_idx, end = tlist.token_next_by(m=sql.Case.M_CLOSE) + tlist.insert_before(end_idx, self.nl()) def _process_default(self, tlist, stmts=True): self._split_statements(tlist) if stmts else None diff --git a/sqlparse/sql.py b/sqlparse/sql.py index 52b3bf1..9656390 100644 --- a/sqlparse/sql.py +++ b/sqlparse/sql.py @@ -209,109 +209,138 @@ class TokenList(Token): if start is None: return None - if not isinstance(start, int): - start = self.token_index(start) + 1 - if not isinstance(funcs, (list, tuple)): funcs = (funcs,) if reverse: - iterable = reversed(self.tokens[end:start - 1]) + assert end is None + for idx in range(start - 2, -1, -1): + token = self.tokens[idx] + for func in funcs: + if func(token): + return idx, token else: - iterable = self.tokens[start:end] + for idx, token in enumerate(self.tokens[start:end], start=start): + for func in funcs: + if func(token): + return idx, token + return None, None + + def token_first(self, skip_ws=True, skip_cm=False): + """Returns the first child token. - for token in iterable: - for func in funcs: - if func(token): - return token + If *skip_ws* is ``True`` (the default), whitespace + tokens are ignored. - def token_next_by(self, i=None, m=None, t=None, idx=0, end=None): + if *skip_cm* is ``True`` (default: ``False``), comments are + ignored too. + """ + # this on is inconsistent, using Comment instead of T.Comment... + funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or + (skip_cm and imt(tk, t=T.Comment, i=Comment))) + return self._token_matching(funcs)[1] + + def token_next_by(self, i=None, m=None, t=None, idx=-1, end=None): funcs = lambda tk: imt(tk, i, m, t) + idx += 1 return self._token_matching(funcs, idx, end) - def token_not_matching(self, idx, funcs): + def token_not_matching(self, funcs, idx): funcs = (funcs,) if not isinstance(funcs, (list, tuple)) else funcs funcs = [lambda tk: not func(tk) for func in funcs] return self._token_matching(funcs, idx) - def token_matching(self, idx, funcs): - return self._token_matching(funcs, idx) + def token_matching(self, funcs, idx): + return self._token_matching(funcs, idx)[1] - def token_prev(self, idx=0, skip_ws=True, skip_cm=False): + def token_prev(self, idx, skip_ws=True, skip_cm=False): """Returns the previous token relative to *idx*. If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. ``None`` is returned if there's no previous token. 
""" + if idx is None: + return None, None + idx += 1 # alot of code usage current pre-compensates for this funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or (skip_cm and imt(tk, t=T.Comment, i=Comment))) return self._token_matching(funcs, idx, reverse=True) - def token_next(self, idx=0, skip_ws=True, skip_cm=False): + # TODO: May need to implement skip_cm for upstream changes. + # TODO: May need to re-add default value to idx + def token_next(self, idx, skip_ws=True, skip_cm=False): """Returns the next token relative to *idx*. - If called with idx = 0. Returns the first child token. If *skip_ws* is ``True`` (the default) whitespace tokens are ignored. - If *skip_cm* is ``True`` (default: ``False``), comments are ignored. ``None`` is returned if there's no next token. """ - funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or - (skip_cm and imt(tk, t=T.Comment, i=Comment))) - return self._token_matching(funcs, idx) + if idx is None: + return None, None + idx += 1 # alot of code usage current pre-compensates for this + try: + if not skip_ws: + return idx, self.tokens[idx] + else: + while True: + token = self.tokens[idx] + if not token.is_whitespace(): + return idx, token + idx += 1 + except IndexError: + return None, None def token_index(self, token, start=0): """Return list index of token.""" start = start if isinstance(start, int) else self.token_index(start) return start + self.tokens[start:].index(token) - def tokens_between(self, start, end, include_end=True): - """Return all tokens between (and including) start and end. - - If *include_end* is ``False`` (default is ``True``) the end token - is excluded. - """ - start_idx = self.token_index(start) - end_idx = include_end + self.token_index(end) - return self.tokens[start_idx:end_idx] - - def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False): + def group_tokens(self, grp_cls, start, end, include_end=True, + extend=False): """Replace tokens by an instance of *grp_cls*.""" + start_idx = start + start = self.tokens[start_idx] - while skip_ws and tokens and tokens[-1].is_whitespace(): - tokens = tokens[:-1] + end_idx = end + include_end - left = tokens[0] - idx = self.token_index(left) + # will be needed later for new group_clauses + # while skip_ws and tokens and tokens[-1].is_whitespace(): + # tokens = tokens[:-1] - if extend and isinstance(left, grp_cls): - grp = left - grp.tokens.extend(tokens[1:]) + if extend and isinstance(start, grp_cls): + subtokens = self.tokens[start_idx + 1:end_idx] + + grp = start + grp.tokens.extend(subtokens) + del self.tokens[start_idx + 1:end_idx] + grp.value = text_type(start) else: - grp = grp_cls(tokens) + subtokens = self.tokens[start_idx:end_idx] + grp = grp_cls(subtokens) + self.tokens[start_idx:end_idx] = [grp] + grp.parent = self - for token in tokens: + for token in subtokens: token.parent = grp - # Improve performance. 
LOOP(list.remove()) is O(n**2) operation - self.tokens = [token for token in self.tokens if token not in tokens] - - self.tokens.insert(idx, grp) - grp.parent = self return grp def insert_before(self, where, token): """Inserts *token* before *where*.""" + if not isinstance(where, int): + where = self.token_index(where) token.parent = self - self.tokens.insert(self.token_index(where), token) + self.tokens.insert(where, token) def insert_after(self, where, token, skip_ws=True): """Inserts *token* after *where*.""" - next_token = self.token_next(where, skip_ws=skip_ws) + if not isinstance(where, int): + where = self.token_index(where) + nidx, next_ = self.token_next(where, skip_ws=skip_ws) token.parent = self - if next_token is None: + if next_ is None: self.tokens.append(token) else: - self.insert_before(next_token, token) + self.tokens.insert(nidx, token) def has_alias(self): """Returns ``True`` if an alias is present.""" @@ -321,12 +350,13 @@ class TokenList(Token): """Returns the alias for this identifier or ``None``.""" # "name AS alias" - kw = self.token_next_by(m=(T.Keyword, 'AS')) + kw_idx, kw = self.token_next_by(m=(T.Keyword, 'AS')) if kw is not None: - return self._get_first_name(kw, keywords=True) + return self._get_first_name(kw_idx + 1, keywords=True) # "name alias" or "complicated column expression alias" - if len(self.tokens) > 2 and self.token_next_by(t=T.Whitespace): + _, ws = self.token_next_by(t=T.Whitespace) + if len(self.tokens) > 2 and ws is not None: return self._get_first_name(reverse=True) def get_name(self): @@ -341,24 +371,21 @@ class TokenList(Token): def get_real_name(self): """Returns the real name (object name) of this identifier.""" # a.b - dot = self.token_next_by(m=(T.Punctuation, '.')) - return self._get_first_name(dot) + dot_idx, _ = self.token_next_by(m=(T.Punctuation, '.')) + return self._get_first_name(dot_idx) def get_parent_name(self): """Return name of the parent object if any. A parent object is identified by the first occuring dot. """ - dot = self.token_next_by(m=(T.Punctuation, '.')) - prev_ = self.token_prev(dot) + dot_idx, _ = self.token_next_by(m=(T.Punctuation, '.')) + _, prev_ = self.token_prev(dot_idx) return remove_quotes(prev_.value) if prev_ is not None else None def _get_first_name(self, idx=None, reverse=False, keywords=False): """Returns the name of the first token with a name""" - if idx and not isinstance(idx, int): - idx = self.token_index(idx) + 1 - tokens = self.tokens[idx:] if idx else self.tokens tokens = reversed(tokens) if reverse else tokens types = [T.Name, T.Wildcard, T.String.Symbol] @@ -386,7 +413,7 @@ class Statement(TokenList): Whitespaces and comments at the beginning of the statement are ignored. """ - first_token = self.token_next(skip_cm=True) + first_token = self.token_first(skip_cm=True) if first_token is None: # An "empty" statement that either has not tokens at all # or only whitespace tokens. @@ -399,9 +426,10 @@ class Statement(TokenList): # The WITH keyword should be followed by either an Identifier or # an IdentifierList containing the CTE definitions; the actual # DML keyword (e.g. SELECT, INSERT) will follow next. 
- token = self.token_next(first_token, skip_ws=True) + fidx = self.token_index(first_token) + tidx, token = self.token_next(fidx, skip_ws=True) if isinstance(token, (Identifier, IdentifierList)): - dml_keyword = self.token_next(token, skip_ws=True) + _, dml_keyword = self.token_next(tidx, skip_ws=True) if dml_keyword.ttype == T.Keyword.DML: return dml_keyword.normalized @@ -418,18 +446,18 @@ class Identifier(TokenList): def is_wildcard(self): """Return ``True`` if this identifier contains a wildcard.""" - token = self.token_next_by(t=T.Wildcard) + _, token = self.token_next_by(t=T.Wildcard) return token is not None def get_typecast(self): """Returns the typecast or ``None`` of this object as a string.""" - marker = self.token_next_by(m=(T.Punctuation, '::')) - next_ = self.token_next(marker, skip_ws=False) + midx, marker = self.token_next_by(m=(T.Punctuation, '::')) + nidx, next_ = self.token_next(midx, skip_ws=False) return next_.value if next_ else None def get_ordering(self): """Returns the ordering or ``None`` as uppercase string.""" - ordering = self.token_next_by(t=T.Keyword.Order) + _, ordering = self.token_next_by(t=T.Keyword.Order) return ordering.normalized if ordering else None def get_array_indices(self): @@ -576,7 +604,7 @@ class Function(TokenList): """Return a list of parameters.""" parenthesis = self.tokens[-1] for token in parenthesis.tokens: - if imt(token, i=IdentifierList): + if isinstance(token, IdentifierList): return token.get_identifiers() elif imt(token, i=(Function, Identifier), t=T.Literal): return [token, ] diff --git a/sqlparse/utils.py b/sqlparse/utils.py index 4a8646d..c3542b8 100644 --- a/sqlparse/utils.py +++ b/sqlparse/utils.py @@ -103,18 +103,6 @@ def imt(token, i=None, m=None, t=None): return False -def find_matching(tlist, token, open_pattern, close_pattern): - idx = tlist.token_index(token) - depth = 0 - for token in tlist.tokens[idx:]: - if token.match(*open_pattern): - depth += 1 - elif token.match(*close_pattern): - depth -= 1 - if depth == 0: - return token - - def consume(iterator, n): """Advance the iterator n-steps ahead. 
If n is none, consume entirely.""" deque(itertools.islice(iterator, n), maxlen=0) diff --git a/tests/test_grouping.py b/tests/test_grouping.py index 147162f..272d266 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -128,11 +128,11 @@ class TestGrouping(TestCaseBase): p = sqlparse.parse("select * from (" "select a, b + c as d from table) sub")[0] subquery = p.tokens[-1].tokens[0] - iden_list = subquery.token_next_by(i=sql.IdentifierList) + idx, iden_list = subquery.token_next_by(i=sql.IdentifierList) self.assert_(iden_list is not None) # all the identifiers should be within the IdentifierList - self.assert_(subquery.token_next_by(i=sql.Identifier, - idx=iden_list) is None) + _, ilist = subquery.token_next_by(i=sql.Identifier, idx=idx) + self.assert_(ilist is None) def test_identifier_list_case(self): p = sqlparse.parse('a, case when 1 then 2 else 3 end as b, c')[0] diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 616c321..b55939a 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -2,6 +2,7 @@ import sys +import pytest # noqa from tests.utils import TestCaseBase, load_file import sqlparse @@ -48,7 +49,7 @@ class RegressionTests(TestCaseBase): self.assert_(p.tokens[0].ttype is T.Comment.Single) def test_issue34(self): - t = sqlparse.parse("create")[0].token_next() + t = sqlparse.parse("create")[0].token_first() self.assertEqual(t.match(T.Keyword.DDL, "create"), True) self.assertEqual(t.match(T.Keyword.DDL, "CREATE"), True) @@ -311,10 +312,3 @@ def test_issue207_runaway_format(): " 2 as two,", " 3", " from dual) t0"]) - - -def test_case_within_parenthesis(): - # see issue #164 - s = '(case when 1=1 then 2 else 5 end)' - p = sqlparse.parse(s)[0] - assert isinstance(p[0][1], sql.Case) diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py index 7200682..61eaa3e 100644 --- a/tests/test_tokenize.py +++ b/tests/test_tokenize.py @@ -104,23 +104,21 @@ class TestTokenList(unittest.TestCase): def test_token_first(self): p = sqlparse.parse(' select foo')[0] - first = p.token_next() + first = p.token_first() self.assertEqual(first.value, 'select') - self.assertEqual(p.token_next(skip_ws=False).value, ' ') - self.assertEqual(sql.TokenList([]).token_next(), None) + self.assertEqual(p.token_first(skip_ws=False).value, ' ') + self.assertEqual(sql.TokenList([]).token_first(), None) def test_token_matching(self): t1 = sql.Token(T.Keyword, 'foo') t2 = sql.Token(T.Punctuation, ',') x = sql.TokenList([t1, t2]) - self.assertEqual(x.token_matching(0, [lambda t: t.ttype is T.Keyword]), - t1) self.assertEqual(x.token_matching( - 0, - [lambda t: t.ttype is T.Punctuation]), - t2) - self.assertEqual(x.token_matching(1, [lambda t: t.ttype is T.Keyword]), - None) + [lambda t: t.ttype is T.Keyword], 0), t1) + self.assertEqual(x.token_matching( + [lambda t: t.ttype is T.Punctuation], 0), t2) + self.assertEqual(x.token_matching( + [lambda t: t.ttype is T.Keyword], 1), None) class TestStream(unittest.TestCase): |
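[Editor's note] One performance detail worth flagging from sql.py: the old group_tokens() rebuilt self.tokens with a list comprehension precisely because, as its removed comment notes, a loop of list.remove() calls is O(n**2); the new version goes further and splices the group in with a single slice assignment. A rough, self-contained timing sketch of why that matters (illustrative only, not from the repo; absolute numbers vary by machine):

```python
import timeit

N, LO, HI = 10000, 100, 200

def remove_loop():
    items = list(range(N))
    grabbed = items[LO:HI]
    for x in grabbed:              # every remove() rescans and shifts the list
        items.remove(x)
    items.insert(LO, grabbed)      # re-insert the "group" where it started

def slice_assign():
    items = list(range(N))
    items[LO:HI] = [items[LO:HI]]  # one splice, as the new group_tokens() does

for fn in (remove_loop, slice_assign):
    print(fn.__name__, timeit.timeit(fn, number=100))
```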