summaryrefslogtreecommitdiff
path: root/lib/extras.py
diff options
context:
space:
mode:
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>2017-02-03 04:56:02 +0000
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>2017-02-03 04:56:02 +0000
commitde8b335d80ecba0305e2b4796373a4064fd450b3 (patch)
tree26a5fab1a9f022adda5d5ad2de78b816fe0350e1 /lib/extras.py
parenta8a3a298f8ade3b0430ff2df0a5d5ee1fe920e3d (diff)
parentca42306d7916647448184907e03c77ff54ebd4f9 (diff)
downloadpsycopg2-sql-compose.tar.gz
Merge branch 'master' into sql-compose (branch: sql-compose)
Diffstat (limited to 'lib/extras.py')
-rw-r--r--lib/extras.py185
1 files changed, 165 insertions, 20 deletions
diff --git a/lib/extras.py b/lib/extras.py
index b59a2c7..38ca17a 100644
--- a/lib/extras.py
+++ b/lib/extras.py
@@ -106,18 +106,21 @@ class DictCursorBase(_cursor):
return res
def __iter__(self):
- if self._prefetch:
- res = super(DictCursorBase, self).__iter__()
- first = res.next()
- if self._query_executed:
- self._build_index()
- if not self._prefetch:
- res = super(DictCursorBase, self).__iter__()
- first = res.next()
-
- yield first
- while 1:
- yield res.next()
+ try:
+ if self._prefetch:
+ res = super(DictCursorBase, self).__iter__()
+ first = res.next()
+ if self._query_executed:
+ self._build_index()
+ if not self._prefetch:
+ res = super(DictCursorBase, self).__iter__()
+ first = res.next()
+
+ yield first
+ while 1:
+ yield res.next()
+ except StopIteration:
+ return
class DictConnection(_connection):
@@ -343,17 +346,20 @@ class NamedTupleCursor(_cursor):
return map(nt._make, ts)
def __iter__(self):
- it = super(NamedTupleCursor, self).__iter__()
- t = it.next()
+ try:
+ it = super(NamedTupleCursor, self).__iter__()
+ t = it.next()
- nt = self.Record
- if nt is None:
- nt = self.Record = self._make_nt()
+ nt = self.Record
+ if nt is None:
+ nt = self.Record = self._make_nt()
- yield nt._make(t)
+ yield nt._make(t)
- while 1:
- yield nt._make(it.next())
+ while 1:
+ yield nt._make(it.next())
+ except StopIteration:
+ return
try:
from collections import namedtuple
@@ -1135,3 +1141,142 @@ def register_composite(name, conn_or_curs, globally=False, factory=None):
caster.array_typecaster, not globally and conn_or_curs or None)
return caster
+
+
+def _paginate(seq, page_size):
+ """Consume an iterable and return it in chunks.
+
+ Every chunk is at most `page_size`. Never return an empty chunk.
+ """
+ page = []
+ it = iter(seq)
+ while 1:
+ try:
+ for i in xrange(page_size):
+ page.append(it.next())
+ yield page
+ page = []
+ except StopIteration:
+ if page:
+ yield page
+ return
+
+
def execute_batch(cur, sql, argslist, page_size=100):
    r"""Execute groups of statements in fewer server roundtrips.

    Execute *sql* several times, against all parameters set (sequences or
    mappings) found in *argslist*.

    The function is semantically similar to

    .. parsed-literal::

        *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ )

    but has a different implementation: Psycopg will join the statements into
    fewer multi-statement commands, each one containing at most *page_size*
    statements, resulting in a reduced number of server roundtrips.

    :param cur: the cursor used to compose (``mogrify()``) and run
        (``execute()``) the statements.
    :param sql: the statement to execute once per item of *argslist*.
    :param argslist: iterable of parameter sets, each one merged to *sql*.
    :param page_size: maximum number of statements joined in one command.
    """
    # NOTE: the docstring is a raw string because the reST markup above
    # relies on backslashes that are not valid Python escape sequences.
    for page in _paginate(argslist, page_size=page_size):
        # Bind every parameter set client-side, then send the whole page as
        # a single ';'-separated command. The bytes separator assumes
        # mogrify() returns bytes -- confirm against the cursor in use.
        sqls = [cur.mogrify(sql, args) for args in page]
        cur.execute(b";".join(sqls))
+
+
def execute_values(cur, sql, argslist, template=None, page_size=100):
    '''Execute a statement using :sql:`VALUES` with a sequence of parameters.

    :param cur: the cursor to use to execute the query.

    :param sql: the query to execute. It must contain a single ``%s``
        placeholder, which will be replaced by a `VALUES list`__.
        Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.

    :param argslist: sequence of sequences or dictionaries with the arguments
        to send to the query. The type and content must be consistent with
        *template*.

    :param template: the snippet to merge to every item in *argslist* to
        compose the query. If the *argslist* items are sequences it should
        contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or
        ``"(%s, %s, 42)"`` if there are constant values...). If the
        *argslist* items are mappings it should contain named placeholders
        (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the
        arguments are sequences and use a simple positional template (i.e.
        ``(%s, %s, ...)``), with the number of placeholders sniffed from the
        first element in *argslist*.

    :param page_size: maximum number of *argslist* items to include in every
        statement. If there are more items the function will execute more than
        one statement.

    .. __: https://www.postgresql.org/docs/current/static/queries-values.html

    While :sql:`INSERT` is an obvious candidate for this function it is
    possible to use it with other statements, for example::

        >>> cur.execute(
        ... "create table test (id int primary key, v1 int, v2 int)")

        >>> execute_values(cur,
        ... "INSERT INTO test (id, v1, v2) VALUES %s",
        ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)])

        >>> execute_values(cur,
        ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1)
        ... WHERE test.id = data.id""",
        ... [(1, 20), (4, 50)])

        >>> cur.execute("select * from test order by id")
        >>> cur.fetchall()
        [(1, 20, 3), (4, 50, 6), (7, 8, 9)]

    '''
    # The query is assembled by joining bytes snippets rather than with
    # ``sql % values``: the merged values are bytes, so a non-bytes query
    # is encoded first, using the codec matching the connection encoding.
    if not isinstance(sql, bytes):
        sql = sql.encode(_ext.encodings[cur.connection.encoding])
    pre, post = _split_sql(sql)

    # Everything around the placeholder is the same for every page.
    head = b''.join(pre)
    tail = b''.join(post)

    for page in _paginate(argslist, page_size=page_size):
        if template is None:
            # Sniff the number of fields from the very first item; the
            # template is then kept for all the following pages.
            template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'

        rows = [cur.mogrify(template, args) for args in page]
        cur.execute(head + b','.join(rows) + tail)
+
+
+def _split_sql(sql):
+ """Split *sql* on a single ``%s`` placeholder.
+
+ Split on the %s, perform %% replacement and return pre, post lists of
+ snippets.
+ """
+ curr = pre = []
+ post = []
+ tokens = _re.split(br'(%.)', sql)
+ for token in tokens:
+ if len(token) != 2 or token[:1] != b'%':
+ curr.append(token)
+ continue
+
+ if token[1:] == b's':
+ if curr is pre:
+ curr = post
+ else:
+ raise ValueError(
+ "the query contains more than one '%s' placeholder")
+ elif token[1:] == b'%':
+ curr.append(b'%')
+ else:
+ raise ValueError("unsupported format character: '%s'"
+ % token[1:].decode('ascii', 'replace'))
+
+ if curr is pre:
+ raise ValueError("the query doesn't contain any '%s' placeholder")
+
+ return pre, post