summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndi Albrecht <albrecht.andi@gmail.com>2009-05-14 06:54:06 +0200
committerAndi Albrecht <albrecht.andi@gmail.com>2009-05-14 06:54:06 +0200
commit5ccb54dae178189623b6223ea95e261046c6bb1a (patch)
treec9c7d9ae5aad46e54cddcafec8016ef0f4ed8f77
parentab7666e2f1cf5ef09a808eb187da4c4b3a93b1bb (diff)
downloadsqlparse-5ccb54dae178189623b6223ea95e261046c6bb1a.tar.gz
Make sure that operators in comparsions are not handled too lazy.
-rw-r--r--sqlparse/engine/grouping.py14
-rw-r--r--tests/test_grouping.py7
2 files changed, 18 insertions, 3 deletions
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index fd6af67..181dae4 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -11,6 +11,7 @@ from sqlparse.sql import *
def _group_left_right(tlist, ttype, value, cls,
check_right=lambda t: True,
+ check_left = lambda t: True,
include_semicolon=False):
[_group_left_right(sgroup, ttype, value, cls, check_right,
include_semicolon) for sgroup in tlist.get_sublists()
@@ -20,8 +21,10 @@ def _group_left_right(tlist, ttype, value, cls,
while token:
right = tlist.token_next(tlist.token_index(token))
left = tlist.token_prev(tlist.token_index(token))
- if (right is None or not check_right(right)
- or left is None):
+ if right is None or not check_right(right):
+ token = tlist.token_next_match(tlist.token_index(token)+1,
+ ttype, value)
+ elif left is None or not check_right(left):
token = tlist.token_next_match(tlist.token_index(token)+1,
ttype, value)
else:
@@ -92,7 +95,12 @@ def group_assignment(tlist):
include_semicolon=True)
def group_comparsion(tlist):
- _group_left_right(tlist, T.Operator, None, Comparsion)
+ def _parts_valid(token):
+ return (token.ttype in (T.String.Symbol, T.Name, T.Number,
+ T.Number.Integer, T.Literal)
+ or isinstance(token, (Identifier,)))
+ _group_left_right(tlist, T.Operator, None, Comparsion,
+ check_left=_parts_valid, check_right=_parts_valid)
def group_case(tlist):
diff --git a/tests/test_grouping.py b/tests/test_grouping.py
index 51b50e1..d2f08fe 100644
--- a/tests/test_grouping.py
+++ b/tests/test_grouping.py
@@ -132,6 +132,13 @@ class TestGrouping(TestCaseBase):
self.ndiffAssertEqual(s, p.to_unicode())
self.assertEqual(p.tokens[4].get_alias(), 'view')
+ def test_comparsion_exclude(self):
+ # make sure operators are not handled too lazy
+ p = sqlparse.parse('(+)')[0]
+ self.assert_(isinstance(p.tokens[0], Parenthesis))
+ self.assert_(not isinstance(p.tokens[0].tokens[1], Comparsion))
+ p = sqlparse.parse('(a+1)')[0]
+ self.assert_(isinstance(p.tokens[0].tokens[1], Comparsion))
class TestStatement(TestCaseBase):