summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2017-09-13 11:39:47 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2017-09-13 11:39:47 -0400
commitb58dc5bc67caa7d8871acfd42e149e119844a21d (patch)
tree2737a4c5593b53db4ee61ba3a09ba0d3fb4e6eeb
parent31f80b9eaeb3c3435b7f6679b41e434478b1d11c (diff)
downloadsqlalchemy-ticket_4075.tar.gz
Add multivalued insert context for defaultsticket_4075
wip Change-Id: I6894c7b4a2bce3e83c3ade8af0e5b2f8df37b785 Fixes: #4075
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--lib/sqlalchemy/engine/default.py33
-rw-r--r--lib/sqlalchemy/sql/crud.py3
-rw-r--r--lib/sqlalchemy/sql/elements.py2
4 files changed, 36 insertions, 4 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index b5c95cb17..1719de516 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -977,7 +977,7 @@ class Connection(Connectable):
except BaseException as e:
self._handle_dbapi_exception(e, None, None, None, None)
- ret = ctx._exec_default(default, None)
+ ret = ctx._exec_default(None, default, None)
if self.should_close_with_result:
self.close()
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 227ff0845..7d7e7fc2f 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1177,10 +1177,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.root_connection._handle_dbapi_exception(
e, None, None, None, self)
- def _exec_default(self, default, type_):
+ def _exec_default(self, column, default, type_):
if default.is_sequence:
return self.fire_sequence(default, type_)
elif default.is_callable:
+ self.current_column = column
return default.arg(self)
elif default.is_clause_element:
# TODO: expensive branching here should be
@@ -1195,17 +1196,43 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
else:
return default.arg
+ def get_current_parameters(self, isolate_multiinsert_groups=True):
+ try:
+ parameters = self.current_parameters
+ column = self.current_column
+ except AttributeError:
+ raise exc.InvalidRequestError(
+ "get_current_parameters() can only be invoked in the "
+ "context of a Python side column default function")
+ if isolate_multiinsert_groups and \
+ self.isinsert and \
+ self.compiled.statement._has_multi_parameters:
+ if column._is_multiparam_column:
+ index = column.index + 1
+ d = {column.original.key: parameters[column.key]}
+ else:
+ d = {column.key: parameters[column.key]}
+ index = 0
+ keys = self.compiled.statement.parameters[0].keys()
+ d.update(
+ (key, parameters["%s_m%d" % (key, index)])
+ for key in keys
+ )
+ return d
+ else:
+ return parameters
+
def get_insert_default(self, column):
if column.default is None:
return None
else:
- return self._exec_default(column.default, column.type)
+ return self._exec_default(column, column.default, column.type)
def get_update_default(self, column):
if column.onupdate is None:
return None
else:
- return self._exec_default(column.onupdate, column.type)
+ return self._exec_default(column, column.onupdate, column.type)
def _process_executemany_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 5739c22f9..8421b1e66 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -395,7 +395,10 @@ def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
class _multiparam_column(elements.ColumnElement):
+ _is_multiparam_column = True
+
def __init__(self, original, index):
+ self.index = index
self.key = "%s_m%d" % (original.key, index + 1)
self.original = original
self.default = original.default
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 36a6a6557..478e0c59c 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -3656,6 +3656,8 @@ class ColumnClause(Immutable, ColumnElement):
onupdate = default = server_default = server_onupdate = None
+ _is_multiparam_column = False
+
_memoized_property = util.group_expirable_memoized_property()
def __init__(self, text, type_=None, is_literal=False, _selectable=None):