diff options
author | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2012-04-11 17:59:16 +0100 |
---|---|---|
committer | Daniele Varrazzo <daniele.varrazzo@gmail.com> | 2012-04-11 17:59:16 +0100 |
commit | c86ca7687fe66f01a1b645ebacad735f1b693a56 (patch) | |
tree | ab6d404cea335f28bcb27039b2e330c0a6cf7dec /lib/extras.py | |
parent | 095cce5605cc3bf460298c954ab6a90455e2a81a (diff) | |
download | psycopg2-c86ca7687fe66f01a1b645ebacad735f1b693a56.tar.gz |
Fixed cursor() arguments propagation to other connection classes
Diffstat (limited to 'lib/extras.py')
-rw-r--r-- | lib/extras.py | 46 |
1 files changed, 19 insertions, 27 deletions
diff --git a/lib/extras.py b/lib/extras.py index b1d4d9e..1560edd 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -104,8 +104,7 @@ class DictCursorBase(_cursor): class DictConnection(_connection): """A connection that uses `DictCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = DictCursor + kwargs.setdefault('cursor_factory', DictCursor) return _connection.cursor(self, *args, **kwargs) class DictCursor(DictCursorBase): @@ -196,8 +195,7 @@ class DictRow(list): class RealDictConnection(_connection): """A connection that uses `RealDictCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = RealDictCursor + kwargs.setdefault('cursor_factory', RealDictCursor) return _connection.cursor(self, *args, **kwargs) class RealDictCursor(DictCursorBase): @@ -252,8 +250,7 @@ class RealDictRow(dict): class NamedTupleConnection(_connection): """A connection that uses `NamedTupleCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = NamedTupleCursor + kwargs.setdefault('cursor_factory', NamedTupleCursor) return _connection.cursor(self, *args, **kwargs) class NamedTupleCursor(_cursor): @@ -348,7 +345,7 @@ class LoggingConnection(_connection): self.log = self._logtologger else: self.log = self._logtofile - + def filter(self, msg, curs): """Filter the query before logging it. @@ -357,26 +354,24 @@ class LoggingConnection(_connection): just does nothing. """ return msg - + def _logtofile(self, msg, curs): msg = self.filter(msg, curs) if msg: self._logobj.write(msg + os.linesep) - + def _logtologger(self, msg, curs): msg = self.filter(msg, curs) if msg: self._logobj.debug(msg) - + def _check(self): if not hasattr(self, '_logobj'): raise self.ProgrammingError( "LoggingConnection object has not been initialize()d") - - def cursor(self, name=None): + + def cursor(self, *args, **kwargs): self._check() - if name is None: - return _connection.cursor(self, cursor_factory=LoggingCursor) - else: - return _connection.cursor(self, name, cursor_factory=LoggingCursor) + kwargs.setdefault('cursor_factory', LoggingCursor) + return _connection.cursor(self, *args, **kwargs) class LoggingCursor(_cursor): """A cursor that logs queries using its connection logging facilities.""" @@ -389,19 +384,19 @@ class LoggingCursor(_cursor): def callproc(self, procname, vars=None): try: - return _cursor.callproc(self, procname, vars) + return _cursor.callproc(self, procname, vars) finally: self.connection.log(self.query, self) class MinTimeLoggingConnection(LoggingConnection): """A connection that logs queries based on execution time. - + This is just an example of how to sub-class `LoggingConnection` to provide some extra filtering for the logged queries. Both the `inizialize()` and `filter()` methods are overwritten to make sure that only queries executing for more than ``mintime`` ms are logged. - + Note that this connection uses the specialized cursor `MinTimeLoggingCursor`. """ @@ -414,20 +409,17 @@ class MinTimeLoggingConnection(LoggingConnection): if t > self._mintime: return msg + os.linesep + " (execution time: %d ms)" % t - def cursor(self, name=None): - self._check() - if name is None: - return _connection.cursor(self, cursor_factory=MinTimeLoggingCursor) - else: - return _connection.cursor(self, name, cursor_factory=MinTimeLoggingCursor) - + def cursor(self, *args, **kwargs): + kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) + return LoggingConnection.cursor(self, *args, **kwargs) + class MinTimeLoggingCursor(LoggingCursor): """The cursor sub-class companion to `MinTimeLoggingConnection`.""" def execute(self, query, vars=None): self.timestamp = time.time() return LoggingCursor.execute(self, query, vars) - + def callproc(self, procname, vars=None): self.timestamp = time.time() return LoggingCursor.execute(self, procname, vars) |