summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
authorAdam Greenhall <agreenhall@lyft.com>2015-09-11 21:58:17 -0700
committerVictor Uriarte <victor.m.uriarte@intel.com>2016-06-06 06:31:35 -0700
commit9ad0acafabd8c8216fdacb71310f6ec56ef59ae9 (patch)
tree6704b2b820eae97bfe10f522d8db6c131ddb25cc /sqlparse
parent7bc47f0ab6a83aeeb98906b208cfba03c89bd7bd (diff)
downloadsqlparse-9ad0acafabd8c8216fdacb71310f6ec56ef59ae9.tar.gz
Fix Case statements
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/filters.py29
-rw-r--r--sqlparse/sql.py5
2 files changed, 33 insertions, 1 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index 78c8d34..dad754e 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -391,8 +391,37 @@ class AlignedIndentFilter:
# if not last column in select, add a comma seperator
new_tokens.append(sql.Token(T.Punctuation, ','))
tlist.tokens = new_tokens
+
+ # process any sub-sub statements (like case statements)
+ for sgroup in tlist.get_sublists():
+ self._process(sgroup, base_indent=base_indent)
return tlist
+ def _process_case(self, tlist, base_indent=0):
+ base_offset = base_indent + self._max_kwd_len + len('case ')
+ case_offset = len('when ')
+ cases = tlist.get_cases(skip_ws=True)
+ # align the end as well
+ end_token = tlist.token_next_match(0, T.Keyword, 'END')
+ cases.append((None, [end_token]))
+
+ condition_width = max(len(str(cond)) for cond, value in cases)
+ for i, (cond, value) in enumerate(cases):
+ if cond is None: # else or end
+ stmt = value[0]
+ line = value
+ else:
+ stmt = cond[0]
+ line = cond + value
+ if i > 0:
+ tlist.insert_before(stmt, self.whitespace(base_offset + case_offset - len(str(stmt))))
+ if cond:
+ tlist.insert_after(cond[-1], self.whitespace(condition_width - len(str(cond))))
+
+ if i < len(cases) - 1:
+ # if not the END add a newline
+ tlist.insert_after(line[-1], self.newline())
+
def _process_substatement(self, tlist, base_indent=0):
def _next_token(i):
t = tlist.token_next_match(i, T.Keyword, self.split_words, regex=True)
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 57bf1e7..daa5cf5 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -538,7 +538,7 @@ class Case(TokenList):
M_OPEN = T.Keyword, 'CASE'
M_CLOSE = T.Keyword, 'END'
- def get_cases(self):
+ def get_cases(self, skip_ws=False):
"""Returns a list of 2-tuples (condition, value).
If an ELSE exists condition is None.
@@ -554,6 +554,9 @@ class Case(TokenList):
if token.match(T.Keyword, 'CASE'):
continue
+ elif skip_ws and token.ttype in T.Whitespace:
+ continue
+
elif token.match(T.Keyword, 'WHEN'):
ret.append(([], []))
mode = CONDITION