diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-09-13 11:39:47 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-09-13 11:39:47 -0400 |
commit | b58dc5bc67caa7d8871acfd42e149e119844a21d (patch) | |
tree | 2737a4c5593b53db4ee61ba3a09ba0d3fb4e6eeb /lib/sqlalchemy/engine/default.py | |
parent | 31f80b9eaeb3c3435b7f6679b41e434478b1d11c (diff) | |
download | sqlalchemy-ticket_4075.tar.gz |
Add multivalued insert context for defaultsticket_4075
wip
Change-Id: I6894c7b4a2bce3e83c3ade8af0e5b2f8df37b785
Fixes: #4075
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 227ff0845..7d7e7fc2f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1177,10 +1177,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.root_connection._handle_dbapi_exception( e, None, None, None, self) - def _exec_default(self, default, type_): + def _exec_default(self, column, default, type_): if default.is_sequence: return self.fire_sequence(default, type_) elif default.is_callable: + self.current_column = column return default.arg(self) elif default.is_clause_element: # TODO: expensive branching here should be @@ -1195,17 +1196,43 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: return default.arg + def get_current_parameters(self, isolate_multiinsert_groups=True): + try: + parameters = self.current_parameters + column = self.current_column + except AttributeError: + raise exc.InvalidRequestError( + "get_current_parameters() can only be invoked in the " + "context of a Python side column default function") + if isolate_multiinsert_groups and \ + self.isinsert and \ + self.compiled.statement._has_multi_parameters: + if column._is_multiparam_column: + index = column.index + 1 + d = {column.original.key: parameters[column.key]} + else: + d = {column.key: parameters[column.key]} + index = 0 + keys = self.compiled.statement.parameters[0].keys() + d.update( + (key, parameters["%s_m%d" % (key, index)]) + for key in keys + ) + return d + else: + return parameters + def get_insert_default(self, column): if column.default is None: return None else: - return self._exec_default(column.default, column.type) + return self._exec_default(column, column.default, column.type) def get_update_default(self, column): if column.onupdate is None: return None else: - return self._exec_default(column.onupdate, column.type) + return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] |