diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 227ff0845..7d7e7fc2f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1177,10 +1177,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.root_connection._handle_dbapi_exception( e, None, None, None, self) - def _exec_default(self, default, type_): + def _exec_default(self, column, default, type_): if default.is_sequence: return self.fire_sequence(default, type_) elif default.is_callable: + self.current_column = column return default.arg(self) elif default.is_clause_element: # TODO: expensive branching here should be @@ -1195,17 +1196,43 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: return default.arg + def get_current_parameters(self, isolate_multiinsert_groups=True): + try: + parameters = self.current_parameters + column = self.current_column + except AttributeError: + raise exc.InvalidRequestError( + "get_current_parameters() can only be invoked in the " + "context of a Python side column default function") + if isolate_multiinsert_groups and \ + self.isinsert and \ + self.compiled.statement._has_multi_parameters: + if column._is_multiparam_column: + index = column.index + 1 + d = {column.original.key: parameters[column.key]} + else: + d = {column.key: parameters[column.key]} + index = 0 + keys = self.compiled.statement.parameters[0].keys() + d.update( + (key, parameters["%s_m%d" % (key, index)]) + for key in keys + ) + return d + else: + return parameters + def get_insert_default(self, column): if column.default is None: return None else: - return self._exec_default(column.default, column.type) + return self._exec_default(column, column.default, column.type) def get_update_default(self, column): if column.onupdate is None: return None else: - return self._exec_default(column.onupdate, column.type) + return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] |