diff options
Diffstat (limited to 'lib/sql.py')
-rw-r--r-- | lib/sql.py | 100 |
1 files changed, 98 insertions, 2 deletions
@@ -23,6 +23,8 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. +import re +import collections from psycopg2 import extensions as ext @@ -187,8 +189,102 @@ class Placeholder(Composible): return "%s" -def compose(sql, args=()): - raise NotImplementedError +re_compose = re.compile(""" + % # percent sign + (?: + ([%s]) # either % or s + | \( ([^\)]+) \) s # or a (named)s placeholder (named captured) + ) + """, re.VERBOSE) + + +def compose(sql, args=None): + phs = list(re_compose.finditer(sql)) + + # check placeholders consistent + counts = {'%': 0, 's': 0, None: 0} + for ph in phs: + counts[ph.group(1)] += 1 + + npos = counts['s'] + nnamed = counts[None] + + if npos and nnamed: + raise ValueError( + "the sql string contains both named and positional placeholders") + + elif npos: + if not isinstance(args, collections.Sequence): + raise TypeError( + "the sql string expects values in a sequence, got %s instead" + % type(args).__name__) + + if len(args) != npos: + raise ValueError( + "the sql string expects %s values, got %s" % (npos, len(args))) + + return _compose_seq(sql, phs, args) + + elif nnamed: + if not isinstance(args, collections.Mapping): + raise TypeError( + "the sql string expects values in a mapping, got %s instead" + % type(args)) + + return _compose_map(sql, phs, args) + + else: + if not isinstance(args, collections.Sequence) and args: + raise TypeError( + "the sql string expects no value, got %s instead" % len(args)) + # If args are a mapping, no placeholder is an acceptable case + + # Convert %% into % + return _compose_seq(sql, phs, ()) + + +def _compose_seq(sql, phs, args): + rv = [] + j = 0 + for i, ph in enumerate(phs): + if i: + rv.append(SQL(sql[phs[i - 1].end():ph.start()])) + else: + rv.append(SQL(sql[0:ph.start()])) + + if ph.group(1) == 's': + rv.append(args[j]) + j += 1 + else: + rv.append(SQL('%')) + + if phs: + rv.append(SQL(sql[phs[-1].end():])) + else: + rv.append(sql) + + return Composed(rv) + + +def _compose_map(sql, phs, args): + rv = [] + for i, ph in enumerate(phs): + if i: + rv.append(SQL(sql[phs[i - 1].end():ph.start()])) + else: + rv.append(SQL(sql[0:ph.start()])) + + if ph.group(2): + rv.append(args[ph.group(2)]) + else: + rv.append(SQL('%')) + + if phs: + rv.append(SQL(sql[phs[-1].end():])) + else: + rv.append(sql) + + return Composed(rv) # Alias |