summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine/default.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r--lib/sqlalchemy/engine/default.py33
1 files changed, 30 insertions, 3 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 227ff0845..7d7e7fc2f 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1177,10 +1177,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.root_connection._handle_dbapi_exception(
e, None, None, None, self)
- def _exec_default(self, default, type_):
+ def _exec_default(self, column, default, type_):
if default.is_sequence:
return self.fire_sequence(default, type_)
elif default.is_callable:
+ self.current_column = column
return default.arg(self)
elif default.is_clause_element:
# TODO: expensive branching here should be
@@ -1195,17 +1196,43 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
else:
return default.arg
+ def get_current_parameters(self, isolate_multiinsert_groups=True):
+ try:
+ parameters = self.current_parameters
+ column = self.current_column
+ except AttributeError:
+ raise exc.InvalidRequestError(
+ "get_current_parameters() can only be invoked in the "
+ "context of a Python side column default function")
+ if isolate_multiinsert_groups and \
+ self.isinsert and \
+ self.compiled.statement._has_multi_parameters:
+ if column._is_multiparam_column:
+ index = column.index + 1
+ d = {column.original.key: parameters[column.key]}
+ else:
+ d = {column.key: parameters[column.key]}
+ index = 0
+ keys = self.compiled.statement.parameters[0].keys()
+ d.update(
+ (key, parameters["%s_m%d" % (key, index)])
+ for key in keys
+ )
+ return d
+ else:
+ return parameters
+
def get_insert_default(self, column):
if column.default is None:
return None
else:
- return self._exec_default(column.default, column.type)
+ return self._exec_default(column, column.default, column.type)
def get_update_default(self, column):
if column.onupdate is None:
return None
else:
- return self._exec_default(column.onupdate, column.type)
+ return self._exec_default(column, column.onupdate, column.type)
def _process_executemany_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]