diff options
Diffstat (limited to 'lib/sqlalchemy/engine')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 5 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 71 |
2 files changed, 39 insertions, 37 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 6bef1fabd..f73ede756 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -16,6 +16,8 @@ import base class PoolConnectionProvider(base.ConnectionProvider): def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs): (cargs, cparams) = dialect.create_connect_args(url) + cparams.update(kwargs.pop('connect_args', {})) + if pool is None: kwargs.setdefault('echo', False) kwargs.setdefault('use_threadlocal',True) @@ -29,7 +31,8 @@ class PoolConnectionProvider(base.ConnectionProvider): return dbapi.connect(*cargs, **cparams) except Exception, e: raise exceptions.DBAPIError("Connection failed", e) - self._pool = poolclass(connect, **kwargs) + creator = kwargs.pop('creator', connect) + self._pool = poolclass(creator, **kwargs) else: if isinstance(pool, sqlalchemy.pool.DBProxy): self._pool = pool.get_pool(*cargs, **cparams) diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index e2f5c8b7c..716e5ffb9 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -20,54 +20,53 @@ class EngineStrategy(object): def create(self, *args, **kwargs): """given arguments, returns a new sql.Engine instance.""" raise NotImplementedError() - -class PlainEngineStrategy(EngineStrategy): - def __init__(self): - EngineStrategy.__init__(self, 'plain') - def create(self, name_or_url, **kwargs): +class DefaultEngineStrategy(EngineStrategy): + def create(self, name_or_url, **kwargs): u = url.make_url(name_or_url) module = u.get_module() - args = u.query.copy() - args.update(kwargs) - dialect = module.dialect(**args) + dialect = module.dialect(**kwargs) poolargs = {} - for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool')): - if kwargs.has_key(key[0]): - poolargs[key[1]] = kwargs[key[0]] + for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool'), ('pool_recycle','recycle'),('connect_args', 'connect_args'), ('creator', 'creator')): + if kwargs.has_key(key[0]): + poolargs[key[1]] = kwargs[key[0]] poolclass = getattr(module, 'poolclass', None) if poolclass is not None: - poolargs.setdefault('poolclass', poolclass) - poolargs['use_threadlocal'] = False - provider = default.PoolConnectionProvider(dialect, u, **poolargs) + poolargs.setdefault('poolclass', poolclass) + poolargs['use_threadlocal'] = self.pool_threadlocal() + provider = self.get_pool_provider(dialect, u, **poolargs) - return base.ComposedSQLEngine(provider, dialect, **args) -PlainEngineStrategy() + return self.get_engine(provider, dialect, **kwargs) -class ThreadLocalEngineStrategy(EngineStrategy): + def pool_threadlocal(self): + raise NotImplementedError() + def get_pool_provider(self, dialect, url, **kwargs): + raise NotImplementedError() + def get_engine(self, provider, dialect, **kwargs): + raise NotImplementedError() + +class PlainEngineStrategy(DefaultEngineStrategy): def __init__(self): - EngineStrategy.__init__(self, 'threadlocal') - def create(self, name_or_url, **kwargs): - u = url.make_url(name_or_url) - module = u.get_module() - - args = u.query.copy() - args.update(kwargs) - dialect = module.dialect(**args) - - poolargs = {} - for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool')): - if kwargs.has_key(key[0]): - poolargs[key[1]] = kwargs[key[0]] - poolclass = getattr(module, 'poolclass', None) - if poolclass is not None: - poolargs.setdefault('poolclass', poolclass) - poolargs['use_threadlocal'] = True - provider = threadlocal.TLocalConnectionProvider(dialect, u, **poolargs) + DefaultEngineStrategy.__init__(self, 'plain') + def pool_threadlocal(self): + return False + def get_pool_provider(self, dialect, url, **poolargs): + return default.PoolConnectionProvider(dialect, url, **poolargs) + def get_engine(self, provider, dialect, **kwargs): + return base.ComposedSQLEngine(provider, dialect, **kwargs) +PlainEngineStrategy() - return threadlocal.TLEngine(provider, dialect, **args) +class ThreadLocalEngineStrategy(DefaultEngineStrategy): + def __init__(self): + DefaultEngineStrategy.__init__(self, 'threadlocal') + def pool_threadlocal(self): + return True + def get_pool_provider(self, dialect, url, **poolargs): + return threadlocal.TLocalConnectionProvider(dialect, url, **poolargs) + def get_engine(self, provider, dialect, **kwargs): + return threadlocal.TLEngine(provider, dialect, **kwargs) ThreadLocalEngineStrategy() |
