diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 126 |
1 files changed, 73 insertions, 53 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 86563cd7c..ceecee364 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -26,16 +26,17 @@ class PoolConnectionProvider(base.ConnectionProvider): class DefaultDialect(base.Dialect): """Default implementation of Dialect""" - def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs): + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.supports_autoclose_results = True self.encoding = encoding self.positional = False self._ischema = None - self._figure_paramstyle(default=default_paramstyle) + self.dbapi = dbapi + self._figure_paramstyle(paramstyle=paramstyle, default=default_paramstyle) - def create_execution_context(self): - return DefaultExecutionContext(self) + def create_execution_context(self, **kwargs): + return DefaultExecutionContext(self, **kwargs) def type_descriptor(self, typeobj): """Provide a database-specific ``TypeEngine`` object, given @@ -56,6 +57,9 @@ class DefaultDialect(base.Dialect): # TODO: probably raise this and fill out # db modules better return 30 + + def supports_alter(self): + return True def oid_column_name(self, column): return None @@ -92,14 +96,8 @@ class DefaultDialect(base.Dialect): def do_execute(self, cursor, statement, parameters, **kwargs): cursor.execute(statement, parameters) - def defaultrunner(self, engine, proxy): - return base.DefaultRunner(engine, proxy) - - def create_cursor(self, connection): - return connection.cursor() - - def create_result_proxy_args(self, connection, cursor): - return dict(should_prefetch=False) + def defaultrunner(self, connection): + return base.DefaultRunner(connection) def _set_paramstyle(self, style): self._paramstyle = style @@ -126,11 +124,10 @@ class DefaultDialect(base.Dialect): return parameters def _figure_paramstyle(self, paramstyle=None, default='named'): - db = self.dbapi() if paramstyle is not None: self._paramstyle = paramstyle - elif db is not None: - self._paramstyle = db.paramstyle + elif self.dbapi is not None: + self._paramstyle = self.dbapi.paramstyle else: self._paramstyle = default @@ -146,10 +143,6 @@ class DefaultDialect(base.Dialect): raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle) def _get_ischema(self): - # We use a property for ischema so that the accessor - # creation only happens as needed, since otherwise we - # have a circularity problem with the generic - # ansisql.engine() if self._ischema is None: import sqlalchemy.databases.information_schema as ischema self._ischema = ischema.ISchema(self) @@ -157,20 +150,49 @@ class DefaultDialect(base.Dialect): ischema = property(_get_ischema, doc="""returns an ISchema object for this engine, which allows access to information_schema tables (if supported)""") class DefaultExecutionContext(base.ExecutionContext): - def __init__(self, dialect): + def __init__(self, dialect, connection, compiled=None, compiled_parameters=None, statement=None, parameters=None): self.dialect = dialect + self.connection = connection + self.compiled = compiled + self.compiled_parameters = compiled_parameters + + if compiled is not None: + self.typemap = compiled.typemap + self.column_labels = compiled.column_labels + self.statement = unicode(compiled) + else: + self.typemap = self.column_labels = None + self.parameters = parameters + self.statement = statement - def pre_exec(self, engine, proxy, compiled, parameters): - self._process_defaults(engine, proxy, compiled, parameters) + if not dialect.supports_unicode_statements(): + self.statement = self.statement.encode('ascii') + + self.cursor = self.create_cursor() + + engine = property(lambda s:s.connection.engine) + + def is_select(self): + return re.match(r'SELECT', self.statement.lstrip(), re.I) + + def create_cursor(self): + return self.connection.connection.cursor() + + def pre_exec(self): + self._process_defaults() + self.parameters = self.dialect.convert_compiled_params(self.compiled_parameters) - def post_exec(self, engine, proxy, compiled, parameters): + def post_exec(self): pass - def get_rowcount(self, cursor): + def get_result_proxy(self): + return base.ResultProxy(self) + + def get_rowcount(self): if hasattr(self, '_rowcount'): return self._rowcount else: - return cursor.rowcount + return self.cursor.rowcount def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount() @@ -187,44 +209,44 @@ class DefaultExecutionContext(base.ExecutionContext): def lastrow_has_defaults(self): return self._lastrow_has_defaults - def set_input_sizes(self, cursor, parameters): + def set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DBAPI types from the bind parameter's ``TypeEngine`` objects. """ - if isinstance(parameters, list): - plist = parameters + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] + plist = [self.compiled_parameters] if self.dialect.positional: inputsizes = [] for params in plist[0:1]: for key in params.positional: typeengine = params.binds[key].type - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes.append(dbtype) - cursor.setinputsizes(*inputsizes) + self.cursor.setinputsizes(*inputsizes) else: inputsizes = {} for params in plist[0:1]: for key in params.keys(): typeengine = params.binds[key].type - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.module) + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) if dbtype is not None: inputsizes[key] = dbtype - cursor.setinputsizes(**inputsizes) + self.cursor.setinputsizes(**inputsizes) - def _process_defaults(self, engine, proxy, compiled, parameters): + def _process_defaults(self): """``INSERT`` and ``UPDATE`` statements, when compiled, may have additional columns added to their ``VALUES`` and ``SET`` lists corresponding to column defaults/onupdates that are present on the ``Table`` object (i.e. ``ColumnDefault``, ``Sequence``, ``PassiveDefault``). This method pre-execs those ``DefaultGenerator`` objects that require pre-execution - and sets their values within the parameter list, and flags the - thread-local state about ``PassiveDefault`` objects that may + and sets their values within the parameter list, and flags this + ExecutionContext about ``PassiveDefault`` objects that may require post-fetching the row after it is inserted/updated. This method relies upon logic within the ``ANSISQLCompiler`` @@ -234,30 +256,28 @@ class DefaultExecutionContext(base.ExecutionContext): statement. """ - if compiled is None: return - - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters + if self.compiled.isinsert: + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.compiled_parameters] + drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] need_lastrowid=False # check the "default" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # check if it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: self._lastrow_has_defaults = True if c.primary_key: need_lastrowid = True # check if its not present at all. see if theres a default # and fire it off, and add to bind parameters. if # its a pk, add the value to our last_inserted_ids list, - # or, if its a SQL-side default, dont do any of that, but we'll need + # or, if its a SQL-side default, let it fire off on the DB side, but we'll need # the SQL-generated value after execution. elif not c.key in param or param.get_original(c.key) is None: if isinstance(c.default, schema.PassiveDefault): @@ -278,19 +298,19 @@ class DefaultExecutionContext(base.ExecutionContext): else: self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param - elif getattr(compiled, 'isupdate', False): - if isinstance(parameters, list): - plist = parameters + elif self.compiled.isupdate: + if isinstance(self.compiled_parameters, list): + plist = self.compiled_parameters else: - plist = [parameters] - drunner = self.dialect.defaultrunner(engine, proxy) + plist = [self.compiled_parameters] + drunner = self.dialect.defaultrunner(base.Connection(self.engine, self.connection.connection)) self._lastrow_has_defaults = False for param in plist: # check the "onupdate" status of each column in the table - for c in compiled.statement.table.c: + for c in self.compiled.statement.table.c: # it will be populated by a SQL clause - we'll need that # after execution. - if c in compiled.inline_params: + if c in self.compiled.inline_params: pass # its not in the bind parameters, and theres an "onupdate" defined for the column; # execute it and add to bind params |
