summaryrefslogtreecommitdiff
path: root/lib/extras.py
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2017-02-02 17:29:17 +0000
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2017-02-02 17:29:17 +0000
commitdc1b4fff9001964c719e3f4471cc5a6fe6533e3a (patch)
treef3ce470b63ba65b21c963d8b6d47c87aa5b22cf4 /lib/extras.py
parentd2fdc5ca9f6d5ac76ee39fc6b7db626345a6c84c (diff)
downloadpsycopg2-dc1b4fff9001964c719e3f4471cc5a6fe6533e3a.tar.gz
Avoid an useless encode/decode roundtrip in execute_values()
Tests moved into a separate module.
Diffstat (limited to 'lib/extras.py')
-rw-r--r--lib/extras.py48
1 files changed, 44 insertions, 4 deletions
diff --git a/lib/extras.py b/lib/extras.py
index 1aad3d1..80034e6 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -1232,10 +1232,50 @@ def execute_values(cur, sql, argslist, template=None, page_size=100):
[(1, 20, 3), (4, 50, 6), (7, 8, 9)])
'''
+ # we can't just use sql % vals because vals is bytes: if sql is bytes
+ # there will be some decoding error because of stupid codec used, and Py3
+ # doesn't implement % on bytes.
+ if not isinstance(sql, bytes):
+ sql = sql.encode(_ext.encodings[cur.connection.encoding])
+ pre, post = _split_sql(sql)
+
for page in _paginate(argslist, page_size=page_size):
if template is None:
template = '(%s)' % ','.join(['%s'] * len(page[0]))
- values = b",".join(cur.mogrify(template, args) for args in page)
- if isinstance(values, bytes):
- values = values.decode(_ext.encodings[cur.connection.encoding])
- cur.execute(sql % (values,))
+ parts = [pre]
+ for args in page:
+ parts.append(cur.mogrify(template, args))
+ parts.append(b',')
+ parts[-1] = post
+ cur.execute(b''.join(parts))
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Return a (pre, post) pair around the ``%s``, with ``%%`` -> ``%`` replacement.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return b''.join(pre), b''.join(post)