summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorprudhvi <prudhvi@pocketgems.com>2013-09-13 11:34:14 -0700
committerprudhvi <prudhvi@pocketgems.com>2013-09-13 13:12:11 -0700
commit9574c16bc7418e07e3a21a1c00e4ad5956b799c3 (patch)
tree5a7f76d66c1ba2d4438fb730816be1e0ceb68007
parent3b41501e850f3de9c0ac3c480bf63e73aa20a45d (diff)
downloadsqlparse-9574c16bc7418e07e3a21a1c00e4ad5956b799c3.tar.gz
Parenthesis, Functions and Arithmetic Expressions
are valid types to group an identifier
-rw-r--r--sqlparse/engine/grouping.py13
-rw-r--r--tests/test_grouping.py9
2 files changed, 18 insertions, 4 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index 07f9392..4ba90aa 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -148,7 +148,9 @@ def group_identifier(tlist):
T.String.Single,
T.Name,
T.Wildcard,
- T.Literal.Number.Integer))))
+ T.Literal.Number.Integer,
+ T.Literal.Number.Float)
+ or isinstance(y, (sql.Parenthesis, sql.Function)))))
for t in tl.tokens[i:]:
# Don't take whitespaces into account.
if t.ttype is T.Whitespace:
@@ -163,8 +165,9 @@ def group_identifier(tlist):
# chooses the next token. if two tokens are found then the
# first is returned.
t1 = tl.token_next_by_type(
- i, (T.String.Symbol, T.String.Single, T.Name))
- t2 = tl.token_next_by_instance(i, sql.Function)
+ i, (T.String.Symbol, T.String.Single, T.Name, T.Literal.Number.Integer,
+ T.Literal.Number.Float))
+ t2 = tl.token_next_by_instance(i, (sql.Function, sql.Parenthesis))
if t1 and t2:
i1 = tl.token_index(t1)
i2 = tl.token_index(t2)
@@ -192,7 +195,9 @@ def group_identifier(tlist):
if identifier_tokens and identifier_tokens[-1].ttype is T.Whitespace:
identifier_tokens = identifier_tokens[:-1]
if not (len(identifier_tokens) == 1
- and isinstance(identifier_tokens[0], sql.Function)):
+ and (isinstance(identifier_tokens[0], (sql.Function, sql.Parenthesis))
+ or identifier_tokens[0].ttype in (T.Literal.Number.Integer,
+ T.Literal.Number.Float))):
group = tlist.group_tokens(sql.Identifier, identifier_tokens)
idx = tlist.token_index(group) + 1
else:
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index b22e543..2b7cefd 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -59,6 +59,15 @@ class TestGrouping(TestCaseBase):
self.assertEquals(types, [T.DML, T.Keyword, None,
T.Keyword, None, T.Punctuation])
+ s = "select 1.0*(a+b) as col, sum(c)/sum(d) from myschema.mytable"
+ parsed = sqlparse.parse(s)[0]
+ self.assertEqual(len(parsed.tokens), 7)
+ self.assert_(isinstance(parsed.tokens[2], sql.IdentifierList))
+ self.assertEqual(len(parsed.tokens[2].tokens), 4)
+ identifiers = list(parsed.tokens[2].get_identifiers())
+ self.assertEqual(len(identifiers), 2)
+ self.assertEquals(identifiers[0].get_alias(), u"col")
+
def test_identifier_wildcard(self):
p = sqlparse.parse('a.*, b.id')[0]
self.assert_(isinstance(p.tokens[0], sql.IdentifierList))