diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-07-13 18:32:42 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-08-07 15:05:00 -0400 |
| commit | 68879d50faa9e2602e55d5d191647b1cf864e5ab (patch) | |
| tree | be9f9e906a3674aa7237ae564eee244931399bae /lib/sqlalchemy | |
| parent | 4b4f8fbf25f1a5a76c1579c1a3fd6ffad07c8c66 (diff) | |
| download | sqlalchemy-68879d50faa9e2602e55d5d191647b1cf864e5ab.tar.gz | |
Enable multi-level selectin polymorphic loading
Change-Id: Icc742bbeecdb7448ce84caccd63e086af16e81c1
Fixes: #4026
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 66 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/unitofwork.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 13 |
7 files changed, 109 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 48c0db851..e4aea3994 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -360,20 +360,26 @@ def _instance_processor( if ( key in context.attributes and context.attributes[key].strategy == - (('selectinload_polymorphic', True), ) and - mapper in context.attributes[key].local_opts['mappers'] - ) or mapper.polymorphic_load == 'selectin': + (('selectinload_polymorphic', True), ) + ): + selectin_load_via = mapper._should_selectin_load( + context.attributes[key].local_opts['entities'], + _polymorphic_from) + else: + selectin_load_via = mapper._should_selectin_load( + None, _polymorphic_from) + if selectin_load_via and selectin_load_via is not _polymorphic_from: # only_load_props goes w/ refresh_state only, and in a refresh # we are a single row query for the exact entity; polymorphic # loading does not apply assert only_load_props is None - callable_ = _load_subclass_via_in(context, path, mapper) + callable_ = _load_subclass_via_in(context, path, selectin_load_via) PostLoad.callable_for_path( - context, load_path, mapper, - callable_, mapper) + context, load_path, selectin_load_via, + callable_, selectin_load_via) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -523,12 +529,15 @@ def _instance_processor( return _instance -@util.dependencies("sqlalchemy.ext.baked") -def _load_subclass_via_in(baked, context, path, mapper): +def _load_subclass_via_in(context, path, entity): + mapper = entity.mapper zero_idx = len(mapper.base_mapper.primary_key) == 1 - q, enable_opt, disable_opt = mapper._subclass_load_via_in + if entity.is_aliased_class: + q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity) + else: + q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper def do_load(context, path, states, load_only, effective_entity): orig_query = context.query diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index d102618a2..9b9457213 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2706,11 +2706,44 @@ class Mapper(InspectionAttr): cols.extend(props[key].columns) return sql.select(cols, cond, use_labels=True) - @_memoized_configured_property + def _iterate_to_target_viawpoly(self, mapper): + if self.isa(mapper): + prev = self + for m in self.iterate_to_root(): + yield m + + if m is not prev and prev not in \ + m._with_polymorphic_mappers: + break + + prev = m + if m is mapper: + break + + def _should_selectin_load(self, enabled_via_opt, polymorphic_from): + if not enabled_via_opt: + # common case, takes place for all polymorphic loads + mapper = polymorphic_from + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == 'selectin': + return m + else: + # uncommon case, selectin load options were used + enabled_via_opt = set(enabled_via_opt) + enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt} + for entity in enabled_via_opt.union([polymorphic_from]): + mapper = entity.mapper + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == 'selectin' or \ + m in enabled_via_opt_mappers: + return enabled_via_opt_mappers.get(m, m) + + return None + @util.dependencies( "sqlalchemy.ext.baked", "sqlalchemy.orm.strategy_options") - def _subclass_load_via_in(self, baked, strategy_options): + def _subclass_load_via_in(self, baked, strategy_options, entity): """Assemble a BakedQuery that can load the columns local to this subclass as a SELECT with IN. @@ -2722,8 +2755,8 @@ class Mapper(InspectionAttr): keep_props = set( [polymorphic_prop] + self._identity_key_props) - disable_opt = strategy_options.Load(self) - enable_opt = strategy_options.Load(self) + disable_opt = strategy_options.Load(entity) + enable_opt = strategy_options.Load(entity) for prop in self.attrs: if prop.parent is self or prop in keep_props: @@ -2747,11 +2780,22 @@ class Mapper(InspectionAttr): else: in_expr = self.primary_key[0] - q = baked.BakedQuery( - self._compiled_cache, - lambda session: session.query(self), - (self, ) - ) + if entity.is_aliased_class: + assert entity.mapper is self + q = baked.BakedQuery( + self._compiled_cache, + lambda session: session.query(entity). + select_entity_from(entity.selectable)._adapt_all_clauses(), + (self, ) + ) + q.spoil() + else: + q = baked.BakedQuery( + self._compiled_cache, + lambda session: session.query(self), + (self, ) + ) + q += lambda q: q.filter( in_expr.in_( sql.bindparam('primary_keys', expanding=True) @@ -2760,6 +2804,10 @@ class Mapper(InspectionAttr): return q, enable_opt, disable_opt + @_memoized_configured_property + def _subclass_load_via_in_mapper(self): + return self._subclass_load_via_in(self) + def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, for all relationships that meet the given cascade rule. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 7c313e635..752f182e5 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1688,6 +1688,7 @@ class Session(_SessionClassMethods): state.key = instance_key self.identity_map.replace(state) + state._orphaned_outside_of_session = False statelib.InstanceState._commit_all_states( ((state, state.dict) for state in states), @@ -1762,6 +1763,7 @@ class Session(_SessionClassMethods): self.add(instance, _warn=False) def _save_or_update_state(self, state): + state._orphaned_outside_of_session = False self._save_or_update_impl(state) mapper = _state_mapper(state) @@ -2271,11 +2273,17 @@ class Session(_SessionClassMethods): proc = new.union(dirty).difference(deleted) for state in proc: - is_orphan = ( - _state_mapper(state)._is_orphan(state) and state.has_identity) - _reg = flush_context.register_object(state, isdelete=is_orphan) - assert _reg, "Failed to add object to the flush context!" - processed.add(state) + is_orphan = _state_mapper(state)._is_orphan(state) + + is_persistent_orphan = is_orphan and state.has_identity + + if is_orphan and not is_persistent_orphan and state._orphaned_outside_of_session: + self._expunge_states([state]) + else: + _reg = flush_context.register_object( + state, isdelete=is_persistent_orphan) + assert _reg, "Failed to add object to the flush context!" + processed.add(state) # put all remaining deletes into the flush context. if objset: diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 1781a41e9..2e53fe9e3 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -61,6 +61,7 @@ class InstanceState(interfaces.InspectionAttr): expired = False _deleted = False _load_pending = False + _orphaned_outside_of_session = False is_instance = True callables = () diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 796f859f8..c47536a02 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1414,7 +1414,7 @@ def selectin_polymorphic(loadopt, classes): """ loadopt.set_class_strategy( {"selectinload_polymorphic": True}, - opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))} + opts={"entities": tuple(sorted((inspect(cls) for cls in classes), key=id))} ) return loadopt diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index ee3e2043b..a3bd53637 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -52,22 +52,24 @@ def track_cascade_events(descriptor, prop): return sess = state.session - if sess: - prop = state.manager.mapper._props[key] + prop = state.manager.mapper._props[key] - if sess._warn_on_events: - sess._flush_warning( - "collection remove" - if prop.uselist - else "related attribute delete") + if sess and sess._warn_on_events: + sess._flush_warning( + "collection remove" + if prop.uselist + else "related attribute delete") - # expunge pending orphans - item_state = attributes.instance_state(item) - if prop._cascade.delete_orphan and \ - item_state in sess._new and \ - prop.mapper._is_orphan(item_state): + # expunge pending orphans + item_state = attributes.instance_state(item) + + if prop._cascade.delete_orphan and \ + prop.mapper._is_orphan(item_state): + if sess and item_state in sess._new: sess.expunge(item) + else: + item_state._orphaned_outside_of_session = True def set_(state, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index c0854ea55..08d0f0aac 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -520,13 +520,10 @@ class AssertsExecutionResults(object): db, callable_, assertsql.CountStatements(count)) @contextlib.contextmanager - def assert_execution(self, *rules): - assertsql.asserter.add_rules(rules) - try: + def assert_execution(self, db, *rules): + with self.sql_execution_asserter(db) as asserter: yield - assertsql.asserter.statement_complete() - finally: - assertsql.asserter.clear_rules() + asserter.assert_(*rules) - def assert_statement_count(self, count): - return self.assert_execution(assertsql.CountStatements(count)) + def assert_statement_count(self, db, count): + return self.assert_execution(db, assertsql.CountStatements(count)) |
