diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2017-02-03 04:56:02 +0000 |
---|---|---|
committer | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2017-02-03 04:56:02 +0000 |
commit | de8b335d80ecba0305e2b4796373a4064fd450b3 (patch) | |
tree | 26a5fab1a9f022adda5d5ad2de78b816fe0350e1 /lib | |
parent | a8a3a298f8ade3b0430ff2df0a5d5ee1fe920e3d (diff) | |
parent | ca42306d7916647448184907e03c77ff54ebd4f9 (diff) | |
download | psycopg2-sql-compose.tar.gz |
Merge branch 'master' into sql-composesql-compose
Diffstat (limited to 'lib')
-rw-r--r-- | lib/__init__.py | 14 | ||||
-rw-r--r-- | lib/extras.py | 185 |
2 files changed, 175 insertions, 24 deletions
diff --git a/lib/__init__.py b/lib/__init__.py index fb22b4c..492b924 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -82,8 +82,7 @@ else: del Decimal, Adapter -def connect(dsn=None, connection_factory=None, cursor_factory=None, - async=False, **kwargs): +def connect(dsn=None, connection_factory=None, cursor_factory=None, **kwargs): """ Create a new database connection. @@ -111,17 +110,24 @@ def connect(dsn=None, connection_factory=None, cursor_factory=None, Using the *cursor_factory* parameter, a new default cursor factory will be used by cursor(). - Using *async*=True an asynchronous connection will be created. + Using *async*=True an asynchronous connection will be created. *async_* is + a valid alias (for Python versions where ``async`` is a keyword). Any other keyword parameter will be passed to the underlying client library: the list of supported parameters depends on the library version. """ + kwasync = {} + if 'async' in kwargs: + kwasync['async'] = kwargs.pop('async') + if 'async_' in kwargs: + kwasync['async_'] = kwargs.pop('async_') + if dsn is None and not kwargs: raise TypeError('missing dsn and no parameters') dsn = _ext.make_dsn(dsn, **kwargs) - conn = _connect(dsn, connection_factory=connection_factory, async=async) + conn = _connect(dsn, connection_factory=connection_factory, **kwasync) if cursor_factory is not None: conn.cursor_factory = cursor_factory diff --git a/lib/extras.py b/lib/extras.py index b59a2c7..38ca17a 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -106,18 +106,21 @@ class DictCursorBase(_cursor): return res def __iter__(self): - if self._prefetch: - res = super(DictCursorBase, self).__iter__() - first = res.next() - if self._query_executed: - self._build_index() - if not self._prefetch: - res = super(DictCursorBase, self).__iter__() - first = res.next() - - yield first - while 1: - yield res.next() + try: + if self._prefetch: + res = super(DictCursorBase, self).__iter__() + first = res.next() + if self._query_executed: + self._build_index() + if not self._prefetch: + res = super(DictCursorBase, self).__iter__() + first = res.next() + + yield first + while 1: + yield res.next() + except StopIteration: + return class DictConnection(_connection): @@ -343,17 +346,20 @@ class NamedTupleCursor(_cursor): return map(nt._make, ts) def __iter__(self): - it = super(NamedTupleCursor, self).__iter__() - t = it.next() + try: + it = super(NamedTupleCursor, self).__iter__() + t = it.next() - nt = self.Record - if nt is None: - nt = self.Record = self._make_nt() + nt = self.Record + if nt is None: + nt = self.Record = self._make_nt() - yield nt._make(t) + yield nt._make(t) - while 1: - yield nt._make(it.next()) + while 1: + yield nt._make(it.next()) + except StopIteration: + return try: from collections import namedtuple @@ -1135,3 +1141,142 @@ def register_composite(name, conn_or_curs, globally=False, factory=None): caster.array_typecaster, not globally and conn_or_curs or None) return caster + + +def _paginate(seq, page_size): + """Consume an iterable and return it in chunks. + + Every chunk is at most `page_size`. Never return an empty chunk. + """ + page = [] + it = iter(seq) + while 1: + try: + for i in xrange(page_size): + page.append(it.next()) + yield page + page = [] + except StopIteration: + if page: + yield page + return + + +def execute_batch(cur, sql, argslist, page_size=100): + """Execute groups of statements in fewer server roundtrips. + + Execute *sql* several times, against all parameters set (sequences or + mappings) found in *argslist*. + + The function is semantically similar to + + .. parsed-literal:: + + *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ ) + + but has a different implementation: Psycopg will join the statements into + fewer multi-statement commands, each one containing at most *page_size* + statements, resulting in a reduced number of server roundtrips. + + """ + for page in _paginate(argslist, page_size=page_size): + sqls = [cur.mogrify(sql, args) for args in page] + cur.execute(b";".join(sqls)) + + +def execute_values(cur, sql, argslist, template=None, page_size=100): + '''Execute a statement using :sql:`VALUES` with a sequence of parameters. + + :param cur: the cursor to use to execute the query. + + :param sql: the query to execute. It must contain a single ``%s`` + placeholder, which will be replaced by a `VALUES list`__. + Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. + + :param argslist: sequence of sequences or dictionaries with the arguments + to send to the query. The type and content must be consistent with + *template*. + + :param template: the snippet to merge to every item in *argslist* to + compose the query. If *argslist* items are sequences it should contain + positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" + if there are constants value...); If *argslist* is items are mapping + it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). + If not specified, assume the arguments are sequence and use a simple + positional template (i.e. ``(%s, %s, ...)``), with the number of + placeholders sniffed by the first element in *argslist*. + + :param page_size: maximum number of *argslist* items to include in every + statement. If there are more items the function will execute more than + one statement. + + .. __: https://www.postgresql.org/docs/current/static/queries-values.html + + While :sql:`INSERT` is an obvious candidate for this function it is + possible to use it with other statements, for example:: + + >>> cur.execute( + ... "create table test (id int primary key, v1 int, v2 int)") + + >>> execute_values(cur, + ... "INSERT INTO test (id, v1, v2) VALUES %s", + ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)]) + + >>> execute_values(cur, + ... """UPDATE test SET v1 = data.v1 FROM (VALUES %s) AS data (id, v1) + ... WHERE test.id = data.id""", + ... [(1, 20), (4, 50)]) + + >>> cur.execute("select * from test order by id") + >>> cur.fetchall() + [(1, 20, 3), (4, 50, 6), (7, 8, 9)]) + + ''' + # we can't just use sql % vals because vals is bytes: if sql is bytes + # there will be some decoding error because of stupid codec used, and Py3 + # doesn't implement % on bytes. + if not isinstance(sql, bytes): + sql = sql.encode(_ext.encodings[cur.connection.encoding]) + pre, post = _split_sql(sql) + + for page in _paginate(argslist, page_size=page_size): + if template is None: + template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' + parts = pre[:] + for args in page: + parts.append(cur.mogrify(template, args)) + parts.append(b',') + parts[-1:] = post + cur.execute(b''.join(parts)) + + +def _split_sql(sql): + """Split *sql* on a single ``%s`` placeholder. + + Split on the %s, perform %% replacement and return pre, post lists of + snippets. + """ + curr = pre = [] + post = [] + tokens = _re.split(br'(%.)', sql) + for token in tokens: + if len(token) != 2 or token[:1] != b'%': + curr.append(token) + continue + + if token[1:] == b's': + if curr is pre: + curr = post + else: + raise ValueError( + "the query contains more than one '%s' placeholder") + elif token[1:] == b'%': + curr.append(b'%') + else: + raise ValueError("unsupported format character: '%s'" + % token[1:].decode('ascii', 'replace')) + + if curr is pre: + raise ValueError("the query doesn't contain any '%s' placeholder") + + return pre, post |