summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/persistence.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/persistence.py')
-rw-r--r--lib/sqlalchemy/orm/persistence.py134
1 files changed, 90 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 1f5507edf..b0fa620e3 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -1,5 +1,5 @@
# orm/persistence.py
-# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors <see AUTHORS file>
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -17,7 +17,7 @@ import operator
from itertools import groupby
from .. import sql, util, exc as sa_exc, schema
from . import attributes, sync, exc as orm_exc, evaluator
-from .util import _state_mapper, state_str, _attr_as_key
+from .base import _state_mapper, state_str, _attr_as_key
from ..sql import expression
from . import loading
@@ -61,7 +61,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
if insert:
_emit_insert_statements(base_mapper, uowtransaction,
cached_connections,
- table, insert)
+ mapper, table, insert)
_finalize_insert_update_commands(base_mapper, uowtransaction,
states_to_insert, states_to_update)
@@ -246,9 +246,12 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
value_params = {}
has_all_pks = True
+ has_all_defaults = True
for col in mapper._cols_by_table[table]:
- if col is mapper.version_id_col:
- params[col.key] = mapper.version_id_generator(None)
+ if col is mapper.version_id_col and \
+ mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(None)
+ params[col.key] = val
else:
# pull straight from the dict for
# pending objects
@@ -261,6 +264,9 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
elif col.default is None and \
col.server_default is None:
params[col.key] = value
+ elif col.server_default is not None and \
+ mapper.base_mapper.eager_defaults:
+ has_all_defaults = False
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
@@ -268,7 +274,8 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
params[col.key] = value
insert.append((state, state_dict, params, mapper,
- connection, value_params, has_all_pks))
+ connection, value_params, has_all_pks,
+ has_all_defaults))
return insert
@@ -315,19 +322,20 @@ def _collect_update_commands(base_mapper, uowtransaction,
params[col.key] = history.added[0]
hasdata = True
else:
- params[col.key] = mapper.version_id_generator(
- params[col._label])
-
- # HACK: check for history, in case the
- # history is only
- # in a different table than the one
- # where the version_id_col is.
- for prop in mapper._columntoproperty.values():
- history = attributes.get_state_history(
- state, prop.key,
- attributes.PASSIVE_NO_INITIALIZE)
- if history.added:
- hasdata = True
+ if mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(params[col._label])
+ params[col.key] = val
+
+ # HACK: check for history, in case the
+ # history is only
+ # in a different table than the one
+ # where the version_id_col is.
+ for prop in mapper._columntoproperty.values():
+ history = attributes.get_state_history(
+ state, prop.key,
+ attributes.PASSIVE_NO_INITIALIZE)
+ if history.added:
+ hasdata = True
else:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
@@ -409,6 +417,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
mapper._get_state_attr_by_column(
state,
state_dict, col)
+
elif col in post_update_cols:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
@@ -478,7 +487,13 @@ def _emit_update_statements(base_mapper, uowtransaction,
sql.bindparam(mapper.version_id_col._label,
type_=mapper.version_id_col.type))
- return table.update(clause)
+ stmt = table.update(clause)
+ if mapper.base_mapper.eager_defaults:
+ stmt = stmt.return_defaults()
+ elif mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
statement = base_mapper._memo(('update', table), update_stmt)
@@ -500,8 +515,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
table,
state,
state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
+ c,
c.context.compiled_parameters[0],
value_params)
rows += c.rowcount
@@ -521,44 +535,55 @@ def _emit_update_statements(base_mapper, uowtransaction,
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, table, insert):
+ cached_connections, mapper, table, insert):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
statement = base_mapper._memo(('insert', table), table.insert)
- for (connection, pkeys, hasvalue, has_all_pks), \
+ for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
records in groupby(insert,
lambda rec: (rec[4],
list(rec[2].keys()),
bool(rec[5]),
- rec[6])
+ rec[6], rec[7])
):
- if has_all_pks and not hasvalue:
+ if \
+ (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ ) and has_all_pks and not hasvalue:
+
records = list(records)
multiparams = [rec[2] for rec in records]
+
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper,
- conn, value_params, has_all_pks), \
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
last_inserted_params in \
zip(records, c.context.compiled_parameters):
_postfetch(
- mapper,
+ mapper_rec,
uowtransaction,
table,
state,
state_dict,
- c.context.prefetch_cols,
- c.context.postfetch_cols,
+ c,
last_inserted_params,
value_params)
else:
- for state, state_dict, params, mapper, \
+ if not has_all_defaults and base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+
+ for state, state_dict, params, mapper_rec, \
connection, value_params, \
- has_all_pks in records:
+ has_all_pks, has_all_defaults in records:
if value_params:
result = connection.execute(
@@ -574,23 +599,22 @@ def _emit_insert_statements(base_mapper, uowtransaction,
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
- prop = mapper._columntoproperty[col]
+ prop = mapper_rec._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(
+ mapper_rec._set_state_attr_by_column(
state,
state_dict,
col, pk)
_postfetch(
- mapper,
+ mapper_rec,
uowtransaction,
table,
state,
state_dict,
- result.context.prefetch_cols,
- result.context.postfetch_cols,
+ result,
result.context.compiled_parameters[0],
value_params)
@@ -699,14 +723,25 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
if readonly:
state._expire_attributes(state.dict, readonly)
- # if eager_defaults option is enabled,
- # refresh whatever has been expired.
- if base_mapper.eager_defaults and state.unloaded:
+ # if eager_defaults option is enabled, load
+ # all expired cols. Else if we have a version_id_col, make sure
+ # it isn't expired.
+ toload_now = []
+
+ if base_mapper.eager_defaults:
+ toload_now.extend(state._unloaded_non_object)
+ elif mapper.version_id_col is not None and \
+ mapper.version_id_generator is False:
+ prop = mapper._columntoproperty[mapper.version_id_col]
+ if prop.key in state.unloaded:
+ toload_now.extend([prop.key])
+
+ if toload_now:
state.key = base_mapper._identity_key_from_state(state)
loading.load_on_ident(
uowtransaction.session.query(base_mapper),
state.key, refresh_state=state,
- only_load_props=state.unloaded)
+ only_load_props=toload_now)
# call after_XXX extensions
if not has_identity:
@@ -716,15 +751,26 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
def _postfetch(mapper, uowtransaction, table,
- state, dict_, prefetch_cols, postfetch_cols,
- params, value_params):
+ state, dict_, result, params, value_params):
"""Expire attributes in need of newly persisted database state,
after an INSERT or UPDATE statement has proceeded for that
state."""
+ prefetch_cols = result.context.prefetch_cols
+ postfetch_cols = result.context.postfetch_cols
+ returning_cols = result.context.returning_cols
+
if mapper.version_id_col is not None:
prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+ if returning_cols:
+ row = result.context.returned_defaults
+ if row is not None:
+ for col in returning_cols:
+ if col.primary_key:
+ continue
+ mapper._set_state_attr_by_column(state, dict_, col, row[col])
+
for c in prefetch_cols:
if c.key in params and c in mapper._columntoproperty:
mapper._set_state_attr_by_column(state, dict_, c, params[c.key])