summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2010-12-11 03:05:03 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2010-12-11 03:05:03 -0500
commit9c0755640c5f1d45596ff7234d2d42f1c92d09e0 (patch)
treed742ffa4269a28d9dc7e9017876af502a13a02fd /lib/sqlalchemy
parent66e5de30f2e01593182058091075780b41411a78 (diff)
downloadsqlalchemy-9c0755640c5f1d45596ff7234d2d42f1c92d09e0.tar.gz
- clean up the batch insert thing
- add a test for batch inserts
- don't need the elaborate _inserted_primary_key thing
- take some cruft out of ExecutionContext, ResultProxy; EC members can be non-underscored; have mapper just call the EC members for now
- simplify "connection_callable": no need for a "flush_opts" dictionary since this point of expansion is not needed
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/base.py36
-rw-r--r--lib/sqlalchemy/engine/default.py65
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py4
-rw-r--r--lib/sqlalchemy/orm/mapper.py182
-rw-r--r--lib/sqlalchemy/orm/session.py3
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py3
6 files changed, 136 insertions, 157 deletions
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4a00ebda2..d58460fb8 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -2448,8 +2448,16 @@ class ResultProxy(object):
did not explicitly specify returning().
"""
+
+ if not self.context.isinsert:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() expression construct.")
+ elif self.context._is_explicit_returning:
+ raise exc.InvalidRequestError(
+ "Can't call inserted_primary_key when returning() "
+ "is used.")
- return self.context._inserted_primary_key
+ return self.context.inserted_primary_key
@util.deprecated("0.6", "Use :attr:`.ResultProxy.inserted_primary_key`")
def last_inserted_ids(self):
@@ -2458,22 +2466,24 @@ class ResultProxy(object):
return self.inserted_primary_key
def last_updated_params(self):
- """Return ``last_updated_params()`` from the underlying
- ExecutionContext.
-
- See ExecutionContext for details.
- """
+ """Return the collection of updated parameters from this
+ execution.
- return self.context.last_updated_params
+ """
+ if self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
def last_inserted_params(self):
- """Return ``last_inserted_params()`` from the underlying
- ExecutionContext.
-
- See ExecutionContext for details.
+ """Return the collection of inserted parameters from this
+ execution.
+
"""
-
- return self.context.last_inserted_params
+ if self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
def lastrow_has_defaults(self):
"""Return ``lastrow_has_defaults()`` from the underlying
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 63b9e44b3..21603b258 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -400,7 +400,9 @@ class DefaultExecutionContext(base.ExecutionContext):
self.cursor = self.create_cursor()
if self.isinsert or self.isupdate:
self.__process_defaults()
-
+ self.postfetch_cols = self.compiled.postfetch
+ self.prefetch_cols = self.compiled.prefetch
+
processors = dict(
(key, value) for key, value in
( (compiled.bind_names[bindparam],
@@ -541,7 +543,8 @@ class DefaultExecutionContext(base.ExecutionContext):
"""
conn = self._connection
- if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
+ if isinstance(stmt, unicode) and \
+ not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
if self.dialect.positional:
@@ -614,13 +617,14 @@ class DefaultExecutionContext(base.ExecutionContext):
def post_insert(self):
if self.dialect.postfetch_lastrowid and \
- (not len(self._inserted_primary_key) or \
- None in self._inserted_primary_key):
+ (not len(self.inserted_primary_key) or \
+ None in self.inserted_primary_key):
table = self.compiled.statement.table
lastrowid = self.get_lastrowid()
- self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
- for c, v in zip(table.primary_key, self._inserted_primary_key)
+ self.inserted_primary_key = [
+ c is table._autoincrement_column and lastrowid or v
+ for c, v in zip(table.primary_key, self.inserted_primary_key)
]
def _fetch_implicit_returning(self, resultproxy):
@@ -628,16 +632,17 @@ class DefaultExecutionContext(base.ExecutionContext):
row = resultproxy.fetchone()
ipk = []
- for c, v in zip(table.primary_key, self._inserted_primary_key):
+ for c, v in zip(table.primary_key, self.inserted_primary_key):
if v is not None:
ipk.append(v)
else:
ipk.append(row[c])
- self._inserted_primary_key = ipk
+ self.inserted_primary_key = ipk
def lastrow_has_defaults(self):
- return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
+ return (self.isinsert or self.isupdate) and \
+ bool(self.postfetch_cols)
def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
@@ -709,31 +714,6 @@ class DefaultExecutionContext(base.ExecutionContext):
else:
return self._exec_default(column.onupdate)
- @util.memoized_property
- def _inserted_primary_key(self):
-
- if not self.isinsert:
- raise exc.InvalidRequestError(
- "Statement is not an insert() expression construct.")
- elif self._is_explicit_returning:
- raise exc.InvalidRequestError(
- "Can't call inserted_primary_key when returning() "
- "is used.")
-
-
- # lazyily evaluate inserted_primary_key for executemany.
- # for execute(), its already in __dict__.
- if self.executemany:
- return [
- [compiled_parameters.get(c.key, None)
- for c in self.compiled.\
- statement.table.primary_key
- ] for compiled_parameters in self.compiled_parameters
- ]
- else:
- # _inserted_primary_key should be calced here
- assert False
-
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
@@ -764,12 +744,6 @@ class DefaultExecutionContext(base.ExecutionContext):
if val is not None:
param[c.key] = val
del self.current_parameters
-
- if self.isinsert:
- self.last_inserted_params = self.compiled_parameters
- else:
- self.last_updated_params = self.compiled_parameters
-
else:
self.current_parameters = compiled_parameters = \
self.compiled_parameters[0]
@@ -784,19 +758,12 @@ class DefaultExecutionContext(base.ExecutionContext):
compiled_parameters[c.key] = val
del self.current_parameters
- if self.isinsert and not self._is_explicit_returning:
- self._inserted_primary_key = [
+ if self.isinsert:
+ self.inserted_primary_key = [
self.compiled_parameters[0].get(c.key, None)
for c in self.compiled.\
statement.table.primary_key
]
- if self.isinsert:
- self.last_inserted_params = compiled_parameters
- else:
- self.last_updated_params = compiled_parameters
-
- self.postfetch_cols = self.compiled.postfetch
- self.prefetch_cols = self.compiled.prefetch
DefaultDialect.execution_ctx_cls = DefaultExecutionContext
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 78e3f5953..e48cb9fcb 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -50,12 +50,12 @@ class ShardedSession(Session):
self.id_chooser = id_chooser
self.query_chooser = query_chooser
self.__binds = {}
- self._mapper_flush_opts = {'connection_callable':self.connection}
+ self.connection_callable = self.connection
self._query_cls = ShardedQuery
if shards is not None:
for k in shards:
self.bind_shard(k, shards[k])
-
+
def connection(self, mapper=None, instance=None, shard_id=None, **kwargs):
if shard_id is None:
shard_id = self.shard_chooser(mapper, instance)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index f6a5516d9..20242c97c 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1468,9 +1468,9 @@ class Mapper(object):
# if session has a connection callable,
# organize individual states with the connection
# to use for update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
@@ -1550,15 +1550,10 @@ class Mapper(object):
of objects.
This is called within the context of a UOWTransaction during a
- flush operation.
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
- `_save_obj` issues SQL statements not just for instances mapped
- directly by this mapper, but for instances mapped by all
- inheriting mappers as well. This is to maintain proper insert
- ordering among a polymorphic chain of instances. Therefore
- _save_obj is typically called only on a *base mapper*, or a
- mapper which does not inherit from any other mapper.
-
"""
# if batch=false, call _save_obj separately for each object
@@ -1572,9 +1567,9 @@ class Mapper(object):
# if session has a connection callable,
# organize individual states with the connection
# to use for insert/update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
@@ -1592,6 +1587,7 @@ class Mapper(object):
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
+
# call before_XXX extensions
if not has_identity:
mapper.dispatch.on_before_insert(mapper, conn, state)
@@ -1652,10 +1648,9 @@ class Mapper(object):
params = {}
value_params = {}
- hasdata = False
- has_all_pks = True
if isinsert:
+ has_all_pks = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = \
@@ -1669,13 +1664,12 @@ class Mapper(object):
value = prop.get_col_value(col, value)
if value is None:
- if col.default is None and \
- col.server_default is None and \
- col not in pks:
-
- params[col.key] = value
- elif col in pks:
+ if col in pks:
has_all_pks = False
+ elif col.default is None and \
+ col.server_default is None:
+ params[col.key] = value
+
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
else:
@@ -1684,6 +1678,7 @@ class Mapper(object):
insert.append((state, state_dict, params, mapper,
connection, value_params, has_all_pks))
else:
+ hasdata = False
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
@@ -1765,7 +1760,8 @@ class Mapper(object):
else:
hasdata = True
elif col in pks:
- value = state.manager[prop.key].impl.get(state, state_dict)
+ value = state.manager[prop.key].\
+ impl.get(state, state_dict)
if prop.get_col_value:
value = prop.get_col_value(col, value)
params[col._label] = value
@@ -1796,7 +1792,6 @@ class Mapper(object):
statement = self._memo(('update', table), update_stmt)
rows = 0
- postfetch = []
for state, state_dict, params, mapper, \
connection, value_params in update:
@@ -1808,17 +1803,17 @@ class Mapper(object):
c = cached_connections[connection].\
execute(statement, params)
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- c.last_updated_params(), value_params))
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
- for mapper, pf in groupby(
- postfetch, lambda rec: rec[0]
- ):
- mapper._postfetch(uowtransaction, table, pf)
-
-
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.StaleDataError(
@@ -1834,61 +1829,72 @@ class Mapper(object):
if insert:
statement = self._memo(('insert', table), table.insert)
- postfetch = []
- for (connection, pkeys, hasvalue, has_all_pks), records in groupby(
- insert, lambda rec: (rec[4], rec[2].keys(), bool(rec[5]), rec[6])
+ for (connection, pkeys, hasvalue, has_all_pks), \
+ records in groupby(insert,
+ lambda rec: (rec[4],
+ rec[2].keys(),
+ bool(rec[5]),
+ rec[6])
):
if has_all_pks and not hasvalue:
records = list(records)
- multiparams = [params for state, state_dict,
- params, mapper, conn, value_params,
- has_all_pks in records]
+ multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper, conn, value_params, has_all_pks), \
- last_inserted_params in zip(records, c.context.compiled_parameters):
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- last_inserted_params, {}))
+ for (state, state_dict, params, mapper,
+ conn, value_params, has_all_pks), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ last_inserted_params,
+ value_params)
else:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_pks in records:
+ connection, value_params, \
+ has_all_pks in records:
if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
+ result = connection.execute(
+ statement.values(value_params),
+ params)
else:
- c = cached_connections[connection].\
+ result = cached_connections[connection].\
execute(statement, params)
- primary_key = c.inserted_primary_key
+ primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
- # TODO: make sure this inlined code is OK
- # with composites
prop = mapper._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(state,
- state_dict,
- col, pk)
+ mapper._set_state_attr_by_column(
+ state,
+ state_dict,
+ col, pk)
+
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result.context.prefetch_cols,
+ result.context.postfetch_cols,
+ result.context.compiled_parameters[0],
+ value_params)
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- c.last_inserted_params(), value_params))
-
- for mapper, pf in groupby(
- postfetch, lambda rec: rec[0]
- ):
- mapper._postfetch(uowtransaction, table, pf)
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in tups:
@@ -1915,36 +1921,32 @@ class Mapper(object):
mapper.dispatch.on_after_update(mapper, connection, state)
def _postfetch(self, uowtransaction, table,
- recs):
+ state, dict_, prefetch_cols, postfetch_cols,
+ params, value_params):
"""During a flush, expire attributes in need of newly
persisted database state."""
- for m, state, dict_, prefetch_cols, postfetch_cols, \
- params, value_params in recs:
- postfetch_cols = postfetch_cols
- generated_cols = list(prefetch_cols)
-
- if self.version_id_col is not None:
- generated_cols.append(self.version_id_col)
-
- for c in generated_cols:
- if c.key in params and c in self._columntoproperty:
- self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
- if postfetch_cols:
- sessionlib._expire_state(state, state.dict,
- [self._columntoproperty[c].key
- for c in postfetch_cols]
- )
-
- # synchronize newly inserted ids from one table to the next
- # TODO: this still goes a little too often. would be nice to
- # have definitive list of "columns that changed" here
- for m, equated_pairs in self._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- self.passive_updates)
+ if self.version_id_col is not None:
+ prefetch_cols = list(prefetch_cols) + [self.version_id_col]
+
+ for c in prefetch_cols:
+ if c.key in params and c in self._columntoproperty:
+ self._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+ if postfetch_cols:
+ sessionlib._expire_state(state, state.dict,
+ [self._columntoproperty[c].key
+ for c in postfetch_cols]
+ )
+
+ # synchronize newly inserted ids from one table to the next
+ # TODO: this still goes a little too often. would be nice to
+ # have definitive list of "columns that changed" here
+ for m, equated_pairs in self._table_to_equated[table]:
+ sync.populate(state, m, state, m,
+ equated_pairs,
+ uowtransaction,
+ self.passive_updates)
@util.memoized_property
def _table_to_equated(self):
@@ -1970,9 +1972,9 @@ class Mapper(object):
flush operation.
"""
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 3517eab2b..30a84bf1a 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -511,7 +511,6 @@ class Session(object):
self._enable_transaction_accounting = _enable_transaction_accounting
self.twophase = twophase
self._query_cls = query_cls
- self._mapper_flush_opts = {}
if extension:
for ext in util.to_list(extension):
@@ -530,6 +529,8 @@ class Session(object):
dispatch = event.dispatcher(SessionEvents)
+ connection_callable = None
+
def begin(self, subtransactions=False, nested=False):
"""Begin a transaction on this Session.
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 875ce634b..d9d64fe39 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -76,7 +76,6 @@ class UOWEventHandler(interfaces.AttributeExtension):
class UOWTransaction(object):
def __init__(self, session):
self.session = session
- self.mapper_flush_opts = session._mapper_flush_opts
# dictionary used by external actors to
# store arbitrary state information.
@@ -316,7 +315,7 @@ class UOWTransaction(object):
postsort_actions):
rec.execute(self)
-
+
def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().