summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py213
1 files changed, 213 insertions, 0 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
new file mode 100644
index 000000000..40978204a
--- /dev/null
+++ b/lib/sqlalchemy/engine/default.py
@@ -0,0 +1,213 @@
+# engine/default.py
+# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+
+from sqlalchemy import schema, exceptions, util, sql, types
+import sqlalchemy.pool
+import StringIO, sys, re
+import base
+
+"""provides default implementations of the engine interfaces"""
+
+
+class PoolConnectionProvider(base.ConnectionProvider):
+ def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs):
+ (cargs, cparams) = dialect.create_connect_args(url)
+ if pool is None:
+ kwargs.setdefault('echo', False)
+ kwargs.setdefault('use_threadlocal',True)
+ if poolclass is None:
+ poolclass = sqlalchemy.pool.QueuePool
+ dbapi = dialect.dbapi()
+ if dbapi is None:
+ raise exceptions.InvalidRequestException("Cant get DBAPI module for dialect '%s'" % dialect)
+ self._pool = poolclass(lambda: dbapi.connect(*cargs, **cparams), **kwargs)
+ else:
+ if isinstance(pool, sqlalchemy.pool.DBProxy):
+ self._pool = pool.get_pool(*cargs, **cparams)
+ else:
+ self._pool = pool
+ def get_connection(self):
+ return self._pool.connect()
+ def dispose(self):
+ self._pool.dispose()
+ if hasattr(self, '_dbproxy'):
+ self._dbproxy.dispose()
+
+class DefaultDialect(base.Dialect):
+ """default implementation of Dialect"""
+ def __init__(self, convert_unicode=False, encoding='utf-8', **kwargs):
+ self.convert_unicode = convert_unicode
+ self.supports_autoclose_results = True
+ self.encoding = encoding
+ self.positional = False
+ self.paramstyle = 'named'
+ self._ischema = None
+ self._figure_paramstyle()
+ def create_execution_context(self):
+ return DefaultExecutionContext(self)
+ def type_descriptor(self, typeobj):
+ """provides a database-specific TypeEngine object, given the generic object
+ which comes from the types module. Subclasses will usually use the adapt_type()
+ method in the types module to make this job easy."""
+ if type(typeobj) is type:
+ typeobj = typeobj()
+ return typeobj
+ def oid_column_name(self):
+ return None
+ def supports_sane_rowcount(self):
+ return True
+ def do_begin(self, connection):
+ """implementations might want to put logic here for turning autocommit on/off,
+ etc."""
+ pass
+ def do_rollback(self, connection):
+ """implementations might want to put logic here for turning autocommit on/off,
+ etc."""
+ #print "ENGINE ROLLBACK ON ", connection.connection
+ connection.rollback()
+ def do_commit(self, connection):
+ """implementations might want to put logic here for turning autocommit on/off, etc."""
+ #print "ENGINE COMMIT ON ", connection.connection
+ connection.commit()
+ def do_executemany(self, cursor, statement, parameters, **kwargs):
+ cursor.executemany(statement, parameters)
+ def do_execute(self, cursor, statement, parameters, **kwargs):
+ cursor.execute(statement, parameters)
+ def defaultrunner(self, engine, proxy):
+ return base.DefaultRunner(engine, proxy)
+
+ def _set_paramstyle(self, style):
+ self._paramstyle = style
+ self._figure_paramstyle(style)
+ paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
+
+ def convert_compiled_params(self, parameters):
+ executemany = parameters is not None and isinstance(parameters, list)
+ # the bind params are a CompiledParams object. but all the DBAPI's hate
+ # that object (or similar). so convert it to a clean
+ # dictionary/list/tuple of dictionary/tuple of list
+ if parameters is not None:
+ if self.positional:
+ if executemany:
+ parameters = [p.values() for p in parameters]
+ else:
+ parameters = parameters.values()
+ else:
+ if executemany:
+ parameters = [p.get_raw_dict() for p in parameters]
+ else:
+ parameters = parameters.get_raw_dict()
+ return parameters
+
+ def _figure_paramstyle(self, paramstyle=None):
+ db = self.dbapi()
+ if paramstyle is not None:
+ self._paramstyle = paramstyle
+ elif db is not None:
+ self._paramstyle = db.paramstyle
+ else:
+ self._paramstyle = 'named'
+
+ if self._paramstyle == 'named':
+ self.positional=False
+ elif self._paramstyle == 'pyformat':
+ self.positional=False
+ elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric':
+ # for positional, use pyformat internally, ANSICompiler will convert
+ # to appropriate character upon compilation
+ self.positional = True
+ else:
+ raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
+
+ def _get_ischema(self):
+ # We use a property for ischema so that the accessor
+ # creation only happens as needed, since otherwise we
+ # have a circularity problem with the generic
+ # ansisql.engine()
+ if self._ischema is None:
+ import sqlalchemy.databases.information_schema as ischema
+ self._ischema = ischema.ISchema(self)
+ return self._ischema
+ ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""")
+
+class DefaultExecutionContext(base.ExecutionContext):
+ def __init__(self, dialect):
+ self.dialect = dialect
+ def pre_exec(self, engine, proxy, compiled, parameters):
+ self._process_defaults(engine, proxy, compiled, parameters)
+ def post_exec(self, engine, proxy, compiled, parameters):
+ pass
+ def get_rowcount(self, cursor):
+ if hasattr(self, '_rowcount'):
+ return self._rowcount
+ else:
+ return cursor.rowcount
+ def supports_sane_rowcount(self):
+ return self.dialect.supports_sane_rowcount()
+ def last_inserted_ids(self):
+ return self._last_inserted_ids
+ def last_inserted_params(self):
+ return self._last_inserted_params
+ def last_updated_params(self):
+ return self._last_updated_params
+ def lastrow_has_defaults(self):
+ return self._lastrow_has_defaults
+ def _process_defaults(self, engine, proxy, compiled, parameters):
+ """INSERT and UPDATE statements, when compiled, may have additional columns added to their
+ VALUES and SET lists corresponding to column defaults/onupdates that are present on the
+ Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those
+ DefaultGenerator objects that require pre-execution and sets their values within the
+ parameter list, and flags the thread-local state about
+ PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
+ This method relies upon logic within the ANSISQLCompiler in its visit_insert and
+ visit_update methods that add the appropriate column clauses to the statement when its
+ being compiled, so that these parameters can be bound to the statement."""
+ if compiled is None: return
+ if getattr(compiled, "isinsert", False):
+ if isinstance(parameters, list):
+ plist = parameters
+ else:
+ plist = [parameters]
+ drunner = self.dialect.defaultrunner(engine, proxy)
+ self._lastrow_has_defaults = False
+ for param in plist:
+ last_inserted_ids = []
+ need_lastrowid=False
+ for c in compiled.statement.table.c:
+ if not param.has_key(c.name) or param[c.name] is None:
+ if isinstance(c.default, schema.PassiveDefault):
+ self._lastrow_has_defaults = True
+ newid = drunner.get_column_default(c)
+ if newid is not None:
+ param[c.name] = newid
+ if c.primary_key:
+ last_inserted_ids.append(param[c.name])
+ elif c.primary_key:
+ need_lastrowid = True
+ elif c.primary_key:
+ last_inserted_ids.append(param[c.name])
+ if need_lastrowid:
+ self._last_inserted_ids = None
+ else:
+ self._last_inserted_ids = last_inserted_ids
+ self._last_inserted_params = param
+ elif getattr(compiled, 'isupdate', False):
+ if isinstance(parameters, list):
+ plist = parameters
+ else:
+ plist = [parameters]
+ drunner = self.dialect.defaultrunner(engine, proxy)
+ self._lastrow_has_defaults = False
+ for param in plist:
+ for c in compiled.statement.table.c:
+ if c.onupdate is not None and (not param.has_key(c.name) or param[c.name] is None):
+ value = drunner.get_column_onupdate(c)
+ if value is not None:
+ param[c.name] = value
+ self._last_updated_params = param
+
+