diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2017-02-02 17:29:17 +0000 |
---|---|---|
committer | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2017-02-02 17:29:17 +0000 |
commit | dc1b4fff9001964c719e3f4471cc5a6fe6533e3a (patch) | |
tree | f3ce470b63ba65b21c963d8b6d47c87aa5b22cf4 /lib/extras.py | |
parent | d2fdc5ca9f6d5ac76ee39fc6b7db626345a6c84c (diff) | |
download | psycopg2-dc1b4fff9001964c719e3f4471cc5a6fe6533e3a.tar.gz |
Avoid a useless encode/decode roundtrip in execute_values()
Tests moved into a separate module.
Diffstat (limited to 'lib/extras.py')
-rw-r--r-- | lib/extras.py | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/lib/extras.py b/lib/extras.py index 1aad3d1..80034e6 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -1232,10 +1232,50 @@ def execute_values(cur, sql, argslist, template=None, page_size=100): [(1, 20, 3), (4, 50, 6), (7, 8, 9)]) ''' + # we can't just use sql % vals because vals is bytes: if sql is bytes + # there will be some decoding error because of stupid codec used, and Py3 + # doesn't implement % on bytes. + if not isinstance(sql, bytes): + sql = sql.encode(_ext.encodings[cur.connection.encoding]) + pre, post = _split_sql(sql) + for page in _paginate(argslist, page_size=page_size): if template is None: template = '(%s)' % ','.join(['%s'] * len(page[0])) - values = b",".join(cur.mogrify(template, args) for args in page) - if isinstance(values, bytes): - values = values.decode(_ext.encodings[cur.connection.encoding]) - cur.execute(sql % (values,)) + parts = [pre] + for args in page: + parts.append(cur.mogrify(template, args)) + parts.append(b',') + parts[-1] = post + cur.execute(b''.join(parts)) + + +def _split_sql(sql): + """Split *sql* on a single ``%s`` placeholder. + + Return a (pre, post) pair around the ``%s``, with ``%%`` -> ``%`` replacement. + """ + curr = pre = [] + post = [] + tokens = _re.split(br'(%.)', sql) + for token in tokens: + if len(token) != 2 or token[:1] != b'%': + curr.append(token) + continue + + if token[1:] == b's': + if curr is pre: + curr = post + else: + raise ValueError( + "the query contains more than one '%s' placeholder") + elif token[1:] == b'%': + curr.append(b'%') + else: + raise ValueError("unsupported format character: '%s'" + % token[1:].decode('ascii', 'replace')) + + if curr is pre: + raise ValueError("the query doesn't contain any '%s' placeholder") + + return b''.join(pre), b''.join(post) |