summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--sqlparse/engine/grouping.py80
-rw-r--r--sqlparse/lexer.py6
-rw-r--r--sqlparse/sql.py18
-rw-r--r--sqlparse/tokens.py1
-rw-r--r--tests/test_parse.py72
5 files changed, 125 insertions, 52 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 9314b89..a317044 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -51,19 +51,21 @@ def _group_left_right(tlist, ttype, value, cls,
ttype, value)
+def _find_matching(idx, tlist, start_ttype, start_value, end_ttype, end_value):
+ depth = 1
+ for tok in tlist.tokens[idx:]:
+ if tok.match(start_ttype, start_value):
+ depth += 1
+ elif tok.match(end_ttype, end_value):
+ depth -= 1
+ if depth == 1:
+ return tok
+ return None
+
+
def _group_matching(tlist, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon=False, recurse=False):
- def _find_matching(i, tl, stt, sva, ett, eva):
- depth = 1
- for n in xrange(i, len(tl.tokens)):
- t = tl.tokens[n]
- if t.match(stt, sva):
- depth += 1
- elif t.match(ett, eva):
- depth -= 1
- if depth == 1:
- return t
- return None
+
[_group_matching(sgroup, start_ttype, start_value, end_ttype, end_value,
cls, include_semicolon) for sgroup in tlist.get_sublists()
if recurse]
@@ -157,16 +159,17 @@ def group_identifier(tlist):
lambda y: (y.match(T.Punctuation, '.')
or y.ttype in (T.Operator,
T.Wildcard,
- T.ArrayIndex,
- T.Name)),
+ T.Name)
+ or isinstance(y, sql.SquareBrackets)),
lambda y: (y.ttype in (T.String.Symbol,
T.Name,
T.Wildcard,
- T.ArrayIndex,
T.Literal.String.Single,
T.Literal.Number.Integer,
T.Literal.Number.Float)
- or isinstance(y, (sql.Parenthesis, sql.Function)))))
+ or isinstance(y, (sql.Parenthesis,
+ sql.SquareBrackets,
+ sql.Function)))))
for t in tl.tokens[i:]:
# Don't take whitespaces into account.
if t.ttype is T.Whitespace:
@@ -275,9 +278,48 @@ def group_identifier_list(tlist):
tcomma = next_
-def group_parenthesis(tlist):
- _group_matching(tlist, T.Punctuation, '(', T.Punctuation, ')',
- sql.Parenthesis)
+def group_brackets(tlist):
+ """Group parentheses () or square brackets []
+
+ This is just like _group_matching, but complicated by the fact that
+ round brackets can contain square bracket groups and vice versa
+ """
+
+ if isinstance(tlist, (sql.Parenthesis, sql.SquareBrackets)):
+ idx = 1
+ else:
+ idx = 0
+
+ # Find the first opening bracket
+ token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
+
+ while token:
+ start_val = token.value # either '(' or '['
+ if start_val == '(':
+ end_val = ')'
+ group_class = sql.Parenthesis
+ else:
+ end_val = ']'
+ group_class = sql.SquareBrackets
+
+ tidx = tlist.token_index(token)
+
+ # Find the corresponding closing bracket
+ end = _find_matching(tidx, tlist, T.Punctuation, start_val,
+ T.Punctuation, end_val)
+
+ if end is None:
+ idx = tidx + 1
+ else:
+ group = tlist.group_tokens(group_class,
+ tlist.tokens_between(token, end))
+
+ # Check for nested bracket groups within this group
+ group_brackets(group)
+ idx = tlist.token_index(group) + 1
+
+ # Find the next opening bracket
+ token = tlist.token_next_match(idx, T.Punctuation, ['(', '['])
def group_comments(tlist):
@@ -393,7 +435,7 @@ def align_comments(tlist):
def group(tlist):
for func in [
group_comments,
- group_parenthesis,
+ group_brackets,
group_functions,
group_where,
group_case,
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index 999eb2c..4707990 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -194,8 +194,10 @@ class Lexer(object):
(r"'(''|\\\\|\\'|[^'])*'", tokens.String.Single),
# not a real string literal in ANSI SQL:
(r'(""|".*?[^\\]")', tokens.String.Symbol),
- (r'(?<=[\w\]])(\[[^\]]*?\])', tokens.Punctuation.ArrayIndex),
- (r'(\[[^\]]+\])', tokens.Name),
+ # sqlite names can be escaped with [square brackets]. left bracket
+ # cannot be preceded by word character or a right bracket --
+ # otherwise it's probably an array index
+ (r'(?<![\w\])])(\[[^\]]+\])', tokens.Name),
(r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword),
(r'END(\s+IF|\s+LOOP)?\b', tokens.Keyword),
(r'NOT NULL\b', tokens.Keyword),
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 8492c5e..9fcb546 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -511,11 +511,12 @@ class Identifier(TokenList):
return ordering.value.upper()
def get_array_indices(self):
- """Returns an iterator of index expressions as strings"""
+ """Returns an iterator of index token lists"""
- # Use [1:-1] index to discard the square brackets
- return (tok.value[1:-1] for tok in self.tokens
- if tok.ttype in T.ArrayIndex)
+ for tok in self.tokens:
+ if isinstance(tok, SquareBrackets):
+ # Use [1:-1] index to discard the square brackets
+ yield tok.tokens[1:-1]
class IdentifierList(TokenList):
@@ -542,6 +543,15 @@ class Parenthesis(TokenList):
return self.tokens[1:-1]
+class SquareBrackets(TokenList):
+ """Tokens between square brackets"""
+
+ __slots__ = ('value', 'ttype', 'tokens')
+
+ @property
+ def _groupable_tokens(self):
+ return self.tokens[1:-1]
+
class Assignment(TokenList):
"""An assignment like 'var := val;'"""
__slots__ = ('value', 'ttype', 'tokens')
diff --git a/sqlparse/tokens.py b/sqlparse/tokens.py
index 014984b..01a9b89 100644
--- a/sqlparse/tokens.py
+++ b/sqlparse/tokens.py
@@ -57,7 +57,6 @@ Literal = Token.Literal
String = Literal.String
Number = Literal.Number
Punctuation = Token.Punctuation
-ArrayIndex = Punctuation.ArrayIndex
Operator = Token.Operator
Comparison = Operator.Comparison
Wildcard = Token.Wildcard
diff --git a/tests/test_parse.py b/tests/test_parse.py
index ad5d2db..f6e796f 100644
--- a/tests/test_parse.py
+++ b/tests/test_parse.py
@@ -215,7 +215,7 @@ def test_single_quotes_with_linebreaks(): # issue118
assert p[0].ttype is T.String.Single
-def test_array_indexed_column():
+def test_sqlite_identifiers():
# Make sure we still parse sqlite style escapes
p = sqlparse.parse('[col1],[col2]')[0].tokens
assert (len(p) == 1
@@ -227,39 +227,59 @@ def test_array_indexed_column():
types = [tok.ttype for tok in p.flatten()]
assert types == [T.Name, T.Operator, T.Name]
+
+def test_simple_1d_array_index():
p = sqlparse.parse('col[1]')[0].tokens
- assert (len(p) == 1
- and tuple(p[0].get_array_indices()) == ('1',)
- and p[0].get_name() == 'col')
+ assert len(p) == 1
+ assert p[0].get_name() == 'col'
+ indices = list(p[0].get_array_indices())
+ assert (len(indices) == 1 # 1-dimensional index
+ and len(indices[0]) == 1 # index is single token
+ and indices[0][0].value == '1')
- p = sqlparse.parse('col[1][1:5] as mycol')[0].tokens
- assert (len(p) == 1
- and tuple(p[0].get_array_indices()) == ('1', '1:5')
- and p[0].get_name() == 'mycol'
- and p[0].get_real_name() == 'col')
-
- p = sqlparse.parse('col[1][other_col]')[0].tokens
- assert len(p) == 1 and tuple(p[0].get_array_indices()) == ('1', 'other_col')
-
- sql = 'SELECT col1, my_1d_array[2] as alias1, my_2d_array[2][5] as alias2'
- p = sqlparse.parse(sql)[0].tokens
- assert len(p) == 3 and isinstance(p[2], sqlparse.sql.IdentifierList)
- ids = list(p[2].get_identifiers())
- assert (ids[0].get_name() == 'col1'
- and tuple(ids[0].get_array_indices()) == ()
- and ids[1].get_name() == 'alias1'
- and ids[1].get_real_name() == 'my_1d_array'
- and tuple(ids[1].get_array_indices()) == ('2',)
- and ids[2].get_name() == 'alias2'
- and ids[2].get_real_name() == 'my_2d_array'
- and tuple(ids[2].get_array_indices()) == ('2', '5'))
+
+def test_2d_array_index():
+ p = sqlparse.parse('col[x][(y+1)*2]')[0].tokens
+ assert len(p) == 1
+ assert p[0].get_name() == 'col'
+ assert len(list(p[0].get_array_indices())) == 2 # 2-dimensional index
+
+
+def test_array_index_function_result():
+ p = sqlparse.parse('somefunc()[1]')[0].tokens
+ assert len(p) == 1
+ assert len(list(p[0].get_array_indices())) == 1
+
+
+def test_schema_qualified_array_index():
+ p = sqlparse.parse('schem.col[1]')[0].tokens
+ assert len(p) == 1
+ assert p[0].get_parent_name() == 'schem'
+ assert p[0].get_name() == 'col'
+ assert list(p[0].get_array_indices())[0][0].value == '1'
+
+
+def test_aliased_array_index():
+ p = sqlparse.parse('col[1] x')[0].tokens
+ assert len(p) == 1
+ assert p[0].get_alias() == 'x'
+ assert p[0].get_real_name() == 'col'
+ assert list(p[0].get_array_indices())[0][0].value == '1'
+
+
+def test_array_literal():
+ # See issue #176
+ p = sqlparse.parse('ARRAY[%s, %s]')[0]
+ assert len(p.tokens) == 2
+ assert len(list(p.flatten())) == 7
def test_typed_array_definition():
# array indices aren't grouped with builtins, but make sure we can extract
# identifier names
p = sqlparse.parse('x int, y int[], z int')[0]
- names = [x.get_name() for x in p.get_sublists()]
+ names = [x.get_name() for x in p.get_sublists()
+ if isinstance(x, sqlparse.sql.Identifier)]
assert names == ['x', 'y', 'z']