diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 136 |
1 files changed, 81 insertions, 55 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a5af6ff19..f6c2263b3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -452,14 +452,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): isinsert = False isupdate = False isdelete = False + is_crud = False isddl = False executemany = False result_map = None compiled = None statement = None - postfetch_cols = None - prefetch_cols = None - returning_cols = None _is_implicit_returning = False _is_explicit_returning = False @@ -515,10 +513,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext): if not compiled.can_execute: raise exc.ArgumentError("Not an executable clause") - self.execution_options = compiled.statement._execution_options - if connection._execution_options: - self.execution_options = dict(self.execution_options) - self.execution_options.update(connection._execution_options) + self.execution_options = compiled.statement._execution_options.union( + connection._execution_options) # compiled clauseelement. process bind params, process table defaults, # track collections used by ResultProxy to target and process results @@ -548,6 +544,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() if self.isinsert or self.isupdate or self.isdelete: + self.is_crud = True self._is_explicit_returning = bool(compiled.statement._returning) self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning) @@ -681,10 +678,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self.execution_options.get("no_parameters", False) @util.memoized_property - def is_crud(self): - return self.isinsert or self.isupdate or self.isdelete - - @util.memoized_property def should_autocommit(self): autocommit = self.execution_options.get('autocommit', not self.compiled and @@ -799,52 +792,84 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def supports_sane_multi_rowcount(self): return self.dialect.supports_sane_multi_rowcount - def post_insert(self): - + def _setup_crud_result_proxy(self): + if self.isinsert and \ + not self.executemany: + if not self._is_implicit_returning and \ + not self.compiled.inline and \ + self.dialect.postfetch_lastrowid: + + self._setup_ins_pk_from_lastrowid() + + elif not self._is_implicit_returning: + self._setup_ins_pk_from_empty() + + result = self.get_result_proxy() + + if self.isinsert: + if self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + self._setup_ins_pk_from_implicit_returning(row) + result.close(_autoclose_connection=False) + result._metadata = None + elif not self._is_explicit_returning: + result.close(_autoclose_connection=False) + result._metadata = None + elif self.isupdate and self._is_implicit_returning: + row = result.fetchone() + self.returned_defaults = row + result.close(_autoclose_connection=False) + result._metadata = None + + elif result._metadata is None: + # no results, get rowcount + # (which requires open cursor on some drivers + # such as kintersbasdb, mxodbc) + result.rowcount + result.close(_autoclose_connection=False) + return result + + def _setup_ins_pk_from_lastrowid(self): key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] + + lastrowid = self.get_lastrowid() + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None) + if proc is not None: + lastrowid = proc(lastrowid) + self.inserted_primary_key = [ + lastrowid if c is autoinc_col else + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - if not self._is_implicit_returning and \ - not self._is_explicit_returning and \ - not self.compiled.inline and \ - self.dialect.postfetch_lastrowid: - - lastrowid = self.get_lastrowid() - autoinc_col = table._autoincrement_column - if autoinc_col is not None: - # apply type post processors to the lastrowid - proc = autoinc_col.type._cached_result_processor( - self.dialect, None) - if proc is not None: - lastrowid = proc(lastrowid) - self.inserted_primary_key = [ - lastrowid if c is autoinc_col else - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - else: - self.inserted_primary_key = [ - self.compiled_parameters[0].get(key_getter(c), None) - for c in table.primary_key - ] - - def _fetch_implicit_returning(self, resultproxy): + def _setup_ins_pk_from_empty(self): + key_getter = self.compiled._key_getters_for_crud_column[2] table = self.compiled.statement.table - row = resultproxy.fetchone() - - ipk = [] - for c, v in zip(table.primary_key, self.inserted_primary_key): - if v is not None: - ipk.append(v) - else: - ipk.append(row[c]) + compiled_params = self.compiled_parameters[0] + self.inserted_primary_key = [ + compiled_params.get(key_getter(c), None) + for c in table.primary_key + ] - self.inserted_primary_key = ipk - self.returned_defaults = row + def _setup_ins_pk_from_implicit_returning(self, row): + key_getter = self.compiled._key_getters_for_crud_column[2] + table = self.compiled.statement.table + compiled_params = self.compiled_parameters[0] - def _fetch_implicit_update_returning(self, resultproxy): - row = resultproxy.fetchone() - self.returned_defaults = row + self.inserted_primary_key = [ + row[col] if value is None else value + for col, value in [ + (col, compiled_params.get(key_getter(col), None)) + for col in table.primary_key + ] + ] def lastrow_has_defaults(self): return (self.isinsert or self.isupdate) and \ @@ -956,14 +981,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] for c in prefetch: if self.isinsert: - val = self.get_insert_default(c) + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg + else: + val = self.get_insert_default(c) else: val = self.get_update_default(c) @@ -972,6 +1000,4 @@ class DefaultExecutionContext(interfaces.ExecutionContext): del self.current_parameters - - DefaultDialect.execution_ctx_cls = DefaultExecutionContext |