diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/strategies.py | 91 | ||||
| -rw-r--r-- | lib/sqlalchemy/util.py | 13 |
4 files changed, 84 insertions, 49 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 07a88659b..205e5aa02 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -396,7 +396,7 @@ class Engine(sql.Executor, Connectable): Connects a ConnectionProvider, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. """ - def __init__(self, connection_provider, dialect, echo=None, **kwargs): + def __init__(self, connection_provider, dialect, echo=None): self.connection_provider = connection_provider self.dialect=dialect self.echo = echo diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 02d3e4608..4af539e78 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -6,7 +6,6 @@ from sqlalchemy import schema, exceptions, util, sql, types -from sqlalchemy import pool as poollib import StringIO, sys, re from sqlalchemy.engine import base @@ -14,30 +13,8 @@ from sqlalchemy.engine import base class PoolConnectionProvider(base.ConnectionProvider): - def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs): - (cargs, cparams) = dialect.create_connect_args(url) - cparams.update(kwargs.pop('connect_args', {})) - - if pool is None: - kwargs.setdefault('echo', False) - kwargs.setdefault('use_threadlocal',True) - if poolclass is None: - poolclass = poollib.QueuePool - dbapi = dialect.dbapi() - if dbapi is None: - raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect) - def connect(): - try: - return dbapi.connect(*cargs, **cparams) - except Exception, e: - raise exceptions.DBAPIError("Connection failed", e) - creator = kwargs.pop('creator', connect) - self._pool = poolclass(creator, **kwargs) - else: - if isinstance(pool, poollib.DBProxy): - self._pool = pool.get_pool(*cargs, **cparams) - else: - self._pool = pool + def __init__(self, pool): + self._pool = pool def get_connection(self): return self._pool.connect() def dispose(self): diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index fe30aeb8d..d48412160 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -6,6 +6,8 @@ this can be accomplished via a mod; see the sqlalchemy/mods package for details. from sqlalchemy.engine import base, default, threadlocal, url +from sqlalchemy import util, exceptions +from sqlalchemy import pool as poollib strategies = {} @@ -22,29 +24,74 @@ class EngineStrategy(object): raise NotImplementedError() class DefaultEngineStrategy(EngineStrategy): - def create(self, name_or_url, **kwargs): + def create(self, name_or_url, **kwargs): + # create url.URL object u = url.make_url(name_or_url) + + # get module from sqlalchemy.databases module = u.get_module() - dialect = module.dialect(**kwargs) + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(module.dialect): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + # create dialect + dialect = module.dialect(**dialect_args) - poolargs = {} - for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool'), ('pool_recycle','recycle'),('connect_args', 'connect_args'), ('creator', 'creator')): - if kwargs.has_key(key[0]): - poolargs[key[1]] = kwargs[key[0]] - poolclass = getattr(module, 'poolclass', None) - if poolclass is not None: - poolargs.setdefault('poolclass', poolclass) - poolargs['use_threadlocal'] = self.pool_threadlocal() - provider = self.get_pool_provider(dialect, u, **poolargs) + # assemble connection arguments + (cargs, cparams) = dialect.create_connect_args(u) + cparams.update(kwargs.pop('connect_args', {})) - return self.get_engine(provider, dialect, **kwargs) + # look for existing pool or create + pool = kwargs.pop('pool', None) + if pool is None: + dbapi = kwargs.pop('module', dialect.dbapi()) + if dbapi is None: + raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect) + def connect(): + try: + return dbapi.connect(*cargs, **cparams) + except Exception, e: + raise exceptions.DBAPIError("Connection failed", e) + creator = kwargs.pop('creator', connect) + + poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool)) + pool_args = {} + # consume pool arguments from kwargs, translating a few of the arguments + for k in util.get_cls_kwargs(poolclass): + tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k) + if tk in kwargs: + pool_args[k] = kwargs.pop(tk) + pool_args['use_threadlocal'] = self.pool_threadlocal() + pool = poolclass(creator, **pool_args) + else: + if isinstance(pool, poollib.DBProxy): + pool = pool.get_pool(*cargs, **cparams) + else: + pool = pool + + provider = self.get_pool_provider(pool) + + # create engine. + engineclass = self.get_engine_cls() + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = kwargs.pop(k) + + # all kwargs should be consumed + if len(kwargs): + raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) + + return engineclass(provider, dialect, **engine_args) def pool_threadlocal(self): raise NotImplementedError() - def get_pool_provider(self, dialect, url, **kwargs): + def get_pool_provider(self, pool): raise NotImplementedError() - def get_engine(self, provider, dialect, **kwargs): + def get_engine_cls(self): raise NotImplementedError() class PlainEngineStrategy(DefaultEngineStrategy): @@ -52,10 +99,10 @@ class PlainEngineStrategy(DefaultEngineStrategy): DefaultEngineStrategy.__init__(self, 'plain') def pool_threadlocal(self): return False - def get_pool_provider(self, dialect, url, **poolargs): - return default.PoolConnectionProvider(dialect, url, **poolargs) - def get_engine(self, provider, dialect, **kwargs): - return base.Engine(provider, dialect, **kwargs) + def get_pool_provider(self, pool): + return default.PoolConnectionProvider(pool) + def get_engine_cls(self): + return base.Engine PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): @@ -63,10 +110,10 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy): DefaultEngineStrategy.__init__(self, 'threadlocal') def pool_threadlocal(self): return True - def get_pool_provider(self, dialect, url, **poolargs): - return threadlocal.TLocalConnectionProvider(dialect, url, **poolargs) - def get_engine(self, provider, dialect, **kwargs): - return threadlocal.TLEngine(provider, dialect, **kwargs) + def get_pool_provider(self, pool): + return threadlocal.TLocalConnectionProvider(pool) + def get_engine_cls(self): + return threadlocal.TLEngine ThreadLocalEngineStrategy() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 3636d5523..bd40039d4 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -57,7 +57,18 @@ class ArgSingleton(type): instance = type.__call__(self, *args) ArgSingleton.instances[hashkey] = instance return instance - + +def get_cls_kwargs(cls): + """return the full set of legal kwargs for the given cls""" + kw = [] + for c in cls.__mro__: + cons = c.__init__ + if hasattr(cons, 'func_code'): + for vn in cons.func_code.co_varnames: + if vn != 'self': + kw.append(vn) + return kw + class SimpleProperty(object): """a "default" property accessor.""" def __init__(self, key): |
