summaryrefslogtreecommitdiff
path: root/sqlparse
diff options
context:
space:
mode:
authorJesús Leganés Combarro "Piranna" <piranna@gmail.com>2012-07-15 13:03:42 +0200
committerJesús Leganés Combarro "Piranna" <piranna@gmail.com>2012-07-15 13:03:42 +0200
commit76361c90e4c2821ef34c1c454e2c63fbb89b09d8 (patch)
tree8e7e7a34dba2ebbb7e15b9b217d551fdd939dec0 /sqlparse
parent92757b0ee0c1543c14f28d3b045ff11ddf5f4297 (diff)
downloadsqlparse-76361c90e4c2821ef34c1c454e2c63fbb89b09d8.tar.gz
Added support for several dirpaths on IncludeStatement
Diffstat (limited to 'sqlparse')
-rw-r--r--sqlparse/filters.py80
1 files changed, 43 insertions, 37 deletions
diff --git a/sqlparse/filters.py b/sqlparse/filters.py
index c5165be..873073e 100644
--- a/sqlparse/filters.py
+++ b/sqlparse/filters.py
@@ -94,16 +94,35 @@ def StripWhitespace(stream):
class IncludeStatement:
"""Filter that enable a INCLUDE statement"""
- def __init__(self, dirpath=".", maxrecursive=10, raiseexceptions=False):
+ def __init__(self, dirpaths=None, maxrecursive=10, raiseexceptions=False):
+ if dirpaths == None:
+ dirpaths = ['.']
+ elif isinstance(dirpaths, basestring):
+ dirpaths = [dirpaths]
+
if maxrecursive <= 0:
raise ValueError('Max recursion limit reached')
- self.dirpath = abspath(dirpath)
+ self.dirpaths = map(abspath, dirpaths)
self.maxRecursive = maxrecursive
self.raiseexceptions = raiseexceptions
self.detected = False
+ def includefile(self, path):
+ with open(path) as f:
+ raw_sql = f.read()
+
+ # Create new FilterStack to parse readed file
+ # and add all its tokens to the main stack recursively
+ stack = FilterStack()
+ stack.preprocess.append(IncludeStatement(self.dirpaths,
+ self.maxRecursive - 1,
+ self.raiseexceptions))
+
+ for tv in stack.run(raw_sql):
+ yield tv
+
@memoize_generator
def process(self, stack, stream):
# Run over all tokens in the stream
@@ -123,45 +142,32 @@ class IncludeStatement:
if token_type in String.Symbol:
# if token_type in tokens.String.Symbol:
- # Get path of file to include
- path = join(self.dirpath, value[1:-1])
-
- try:
- f = open(path)
- raw_sql = f.read()
- f.close()
-
- # There was a problem loading the include file
- except IOError, err:
- # Raise the exception to the interpreter
- if self.raiseexceptions:
- raise
+ found = False
+ maxrecursionreached = None
- # Put the exception as a comment on the SQL code
- yield Comment, u'-- IOError: %s\n' % err
-
- else:
- # Create new FilterStack to parse readed file
- # and add all its tokens to the main stack recursively
+ # Get path of file to include
+ value = value[1:-1]
+ for path in self.dirpaths:
try:
- filtr = IncludeStatement(self.dirpath,
- self.maxRecursive - 1,
- self.raiseexceptions)
+ path = join(path, value)
- # Max recursion limit reached
- except ValueError, err:
- # Raise the exception to the interpreter
- if self.raiseexceptions:
- raise
+ except IOError, err:
+ continue
- # Put the exception as a comment on the SQL code
- yield Comment, u'-- ValueError: %s\n' % err
-
- stack = FilterStack()
- stack.preprocess.append(filtr)
-
- for tv in stack.run(raw_sql):
- yield tv
+ except ValueError, err:
+ maxrecursionreached = err
+ break
+
+ else:
+ for tv in self.includefile(path):
+ yield tv
+ found = True
+ break
+
+ if not found:
+ yield Comment, u'-- IOError: %s\n' % value
+ elif maxrecursionreached:
+ yield Comment, u'-- ValueError: %s\n' % maxrecursionreached
# Set normal mode
self.detected = False