diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-08-25 16:27:10 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-08-25 16:27:10 +0000 |
| commit | 8260ca2723ab3b08339ec9273fa729f70862fdf3 (patch) | |
| tree | 8b32cc35e8b63a16eb55e5f136888cba5d4356ea /lib/sqlalchemy/pool.py | |
| parent | 367e3b61a1031e51ffd13acbc71245088f5ed15a (diff) | |
| download | sqlalchemy-8260ca2723ab3b08339ec9273fa729f70862fdf3.tar.gz | |
- cleanup on connection methods + documentation. custom DBAPI
arguments specified in query string, 'connect_args' argument
to 'create_engine', or custom creation function via 'creator'
function to 'create_engine'.
- added "recycle" argument to Pool, is "pool_recycle" on create_engine,
defaults to 3600 seconds; connections after this age will be closed and
replaced with a new one, to handle db's that automatically close
stale connections [ticket:274]
Diffstat (limited to 'lib/sqlalchemy/pool.py')
| -rw-r--r-- | lib/sqlalchemy/pool.py | 135 |
1 files changed, 83 insertions, 52 deletions
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index f601dd9a2..211f96070 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -10,7 +10,12 @@ on a thread local basis. Also provides a DBAPI2 transparency layer so that pool be managed automatically, based on module type and connect arguments, simply by calling regular DBAPI connect() methods.""" -import weakref, string, cPickle +import weakref, string, time, sys +try: + import cPickle as pickle +except: + import pickle + from sqlalchemy import util, exceptions import sqlalchemy.queue as Queue @@ -71,40 +76,36 @@ def clear_managers(): class Pool(object): - def __init__(self, echo = False, use_threadlocal = True, logger=None): + def __init__(self, creator, recycle=-1, echo = False, use_threadlocal = True, logger=None): self._threadconns = weakref.WeakValueDictionary() + self._creator = creator + self._recycle = recycle self._use_threadlocal = use_threadlocal self.echo = echo self._logger = logger or util.Logger(origin='pool') def unique_connection(self): - return ConnectionFairy(self).checkout() - + return _ConnectionFairy(self).checkout() + + def create_connection(self): + return _ConnectionRecord(self) + def connect(self): if not self._use_threadlocal: - return ConnectionFairy(self).checkout() + return _ConnectionFairy(self).checkout() try: - return self._threadconns[thread.get_ident()].checkout() + return self._threadconns[thread.get_ident()].connfairy().checkout() except KeyError: - agent = ConnectionFairy(self).checkout() - self._threadconns[thread.get_ident()] = agent + agent = _ConnectionFairy(self).checkout() + self._threadconns[thread.get_ident()] = agent._threadfairy return agent - def _purge_for_threadlocal(self): - if self._use_threadlocal: - try: - del self._threadconns[thread.get_ident()] - except KeyError: - pass - def return_conn(self, agent): - self._purge_for_threadlocal() - self.do_return_conn(agent.connection) + self.do_return_conn(agent._connection_record) def return_invalid(self, agent): - self._purge_for_threadlocal() - self.do_return_invalid(agent.connection) + self.do_return_invalid(agent._connection_record) def get(self): return self.do_get() @@ -129,29 +130,60 @@ class Pool(object): def __del__(self): self.dispose() - -class ConnectionFairy(object): - def __init__(self, pool, connection=None): + +class _ConnectionRecord(object): + def __init__(self, pool): self.pool = pool - self.__counter = 0 - if connection is not None: - self.connection = connection - else: + self.connection = self.__connect() + def close(self): + self.connection.close() + def get_connection(self): + if self.pool._recycle > -1 and time.time() - self.starttime > self.pool._recycle: + self.pool.log("Connection %s exceeded timeout; recycling" % repr(self.connection)) try: - self.connection = pool.get() - except: - self.connection = None - self.pool.return_invalid(self) - raise - if self.pool.echo: - self.pool.log("Connection %s checked out from pool" % repr(self.connection)) + self.connection.close() + except Exception, e: + self.pool.log("Connection %s threw an error: %s" % (repr(self.connection), str(e))) + self.connection = self.__connect() + return self.connection + def __connect(self): + try: + self.starttime = time.time() + return self.pool._creator() + except: + raise + # TODO: reconnect support here ? + +class _ThreadFairy(object): + """marks a thread identifier as owning a connection, for a thread local pool.""" + def __init__(self, connfairy): + self.connfairy = weakref.ref(connfairy) + +class _ConnectionFairy(object): + """proxies a DBAPI connection object and provides return-on-dereference support""" + def __init__(self, pool): + self._threadfairy = _ThreadFairy(self) + self.__pool = pool + self.__counter = 0 + try: + self._connection_record = pool.get() + self.connection = self._connection_record.get_connection() + except: + self.connection = None # helps with endless __getattr__ loops later on + self._connection_record = None + self.__pool.return_invalid(self) + raise + if self.__pool.echo: + self.__pool.log("Connection %s checked out from pool" % repr(self.connection)) def invalidate(self): - if self.pool.echo: - self.pool.log("Invalidate connection %s" % repr(self.connection)) + if self.__pool.echo: + self.__pool.log("Invalidate connection %s" % repr(self.connection)) self.connection = None - self.pool.return_invalid(self) + self._connection_record = None + self._threadfairy = None + self.__pool.return_invalid(self) def cursor(self, *args, **kwargs): - return CursorFairy(self, self.connection.cursor(*args, **kwargs)) + return _CursorFairy(self, self.connection.cursor(*args, **kwargs)) def __getattr__(self, key): return getattr(self.connection, key) def checkout(self): @@ -167,20 +199,22 @@ class ConnectionFairy(object): self._close() def _close(self): if self.connection is not None: - if self.pool.echo: - self.pool.log("Connection %s being returned to pool" % repr(self.connection)) + if self.__pool.echo: + self.__pool.log("Connection %s being returned to pool" % repr(self.connection)) try: self.connection.rollback() except: # damn mysql -- (todo look for NotSupportedError) pass - self.pool.return_conn(self) - self.pool = None - self.connection = None + self.__pool.return_conn(self) + self.__pool = None + self.connection = None + self._connection_record = None + self._threadfairy = None -class CursorFairy(object): +class _CursorFairy(object): def __init__(self, parent, cursor): - self.parent = parent + self.__parent = parent self.cursor = cursor def __getattr__(self, key): return getattr(self.cursor, key) @@ -189,9 +223,8 @@ class SingletonThreadPool(Pool): """Maintains one connection per each thread, never moving to another thread. this is used for SQLite.""" def __init__(self, creator, pool_size=5, **params): - Pool.__init__(self, **params) + Pool.__init__(self, creator, **params) self._conns = {} - self._creator = creator self.size = pool_size def dispose(self): @@ -234,7 +267,7 @@ class SingletonThreadPool(Pool): try: return self._conns[thread.get_ident()] except KeyError: - c = self._creator() + c = self.create_connection() self._conns[thread.get_ident()] = c if len(self._conns) > self.size: self.cleanup() @@ -243,10 +276,8 @@ class SingletonThreadPool(Pool): class QueuePool(Pool): """uses Queue.Queue to maintain a fixed-size list of connections.""" def __init__(self, creator, pool_size = 5, max_overflow = 10, timeout=30, **params): - Pool.__init__(self, **params) - self._creator = creator + Pool.__init__(self, creator, **params) self._pool = Queue.Queue(pool_size) - self._overflow = 0 - pool_size self._max_overflow = max_overflow self._timeout = timeout @@ -268,7 +299,7 @@ class QueuePool(Pool): if self._max_overflow > -1 and self._overflow >= self._max_overflow: raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out" % (self.size(), self.overflow())) self._overflow += 1 - return self._creator() + return self.create_connection() def dispose(self): while True: @@ -344,5 +375,5 @@ class DBProxy(object): pass def _serialize(self, *args, **params): - return cPickle.dumps([args, params]) + return pickle.dumps([args, params]) |
