diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-09-13 11:39:47 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-09-13 11:39:47 -0400 |
commit | b58dc5bc67caa7d8871acfd42e149e119844a21d (patch) | |
tree | 2737a4c5593b53db4ee61ba3a09ba0d3fb4e6eeb | |
parent | 31f80b9eaeb3c3435b7f6679b41e434478b1d11c (diff) | |
download | sqlalchemy-ticket_4075.tar.gz |
Add multivalued insert context for defaultsticket_4075
wip
Change-Id: I6894c7b4a2bce3e83c3ade8af0e5b2f8df37b785
Fixes: #4075
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 33 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/elements.py | 2 |
4 files changed, 36 insertions, 4 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b5c95cb17..1719de516 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -977,7 +977,7 @@ class Connection(Connectable): except BaseException as e: self._handle_dbapi_exception(e, None, None, None, None) - ret = ctx._exec_default(default, None) + ret = ctx._exec_default(None, default, None) if self.should_close_with_result: self.close() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 227ff0845..7d7e7fc2f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1177,10 +1177,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.root_connection._handle_dbapi_exception( e, None, None, None, self) - def _exec_default(self, default, type_): + def _exec_default(self, column, default, type_): if default.is_sequence: return self.fire_sequence(default, type_) elif default.is_callable: + self.current_column = column return default.arg(self) elif default.is_clause_element: # TODO: expensive branching here should be @@ -1195,17 +1196,43 @@ class DefaultExecutionContext(interfaces.ExecutionContext): else: return default.arg + def get_current_parameters(self, isolate_multiinsert_groups=True): + try: + parameters = self.current_parameters + column = self.current_column + except AttributeError: + raise exc.InvalidRequestError( + "get_current_parameters() can only be invoked in the " + "context of a Python side column default function") + if isolate_multiinsert_groups and \ + self.isinsert and \ + self.compiled.statement._has_multi_parameters: + if column._is_multiparam_column: + index = column.index + 1 + d = {column.original.key: parameters[column.key]} + else: + d = {column.key: parameters[column.key]} + index = 0 + keys = self.compiled.statement.parameters[0].keys() + d.update( + (key, parameters["%s_m%d" % (key, index)]) + for key in keys + ) + return d + else: + return parameters + def get_insert_default(self, column): if column.default is None: return None else: - return self._exec_default(column.default, column.type) + return self._exec_default(column, column.default, column.type) def get_update_default(self, column): if column.onupdate is None: return None else: - return self._exec_default(column.onupdate, column.type) + return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 5739c22f9..8421b1e66 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -395,7 +395,10 @@ def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): class _multiparam_column(elements.ColumnElement): + _is_multiparam_column = True + def __init__(self, original, index): + self.index = index self.key = "%s_m%d" % (original.key, index + 1) self.original = original self.default = original.default diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 36a6a6557..478e0c59c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3656,6 +3656,8 @@ class ColumnClause(Immutable, ColumnElement): onupdate = default = server_default = server_onupdate = None + _is_multiparam_column = False + _memoized_property = util.group_expirable_memoized_property() def __init__(self, text, type_=None, is_literal=False, _selectable=None): |