summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2017-07-13 18:32:42 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2017-08-07 15:05:00 -0400
commit68879d50faa9e2602e55d5d191647b1cf864e5ab (patch)
treebe9f9e906a3674aa7237ae564eee244931399bae /lib/sqlalchemy
parent4b4f8fbf25f1a5a76c1579c1a3fd6ffad07c8c66 (diff)
downloadsqlalchemy-68879d50faa9e2602e55d5d191647b1cf864e5ab.tar.gz
Enable multi-level selectin polymorphic loading
Change-Id: Icc742bbeecdb7448ce84caccd63e086af16e81c1 Fixes: #4026
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/loading.py27
-rw-r--r--lib/sqlalchemy/orm/mapper.py66
-rw-r--r--lib/sqlalchemy/orm/session.py18
-rw-r--r--lib/sqlalchemy/orm/state.py1
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py2
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py26
-rw-r--r--lib/sqlalchemy/testing/assertions.py13
7 files changed, 109 insertions, 44 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 48c0db851..e4aea3994 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -360,20 +360,26 @@ def _instance_processor(
if (
key in context.attributes and
context.attributes[key].strategy ==
- (('selectinload_polymorphic', True), ) and
- mapper in context.attributes[key].local_opts['mappers']
- ) or mapper.polymorphic_load == 'selectin':
+ (('selectinload_polymorphic', True), )
+ ):
+ selectin_load_via = mapper._should_selectin_load(
+ context.attributes[key].local_opts['entities'],
+ _polymorphic_from)
+ else:
+ selectin_load_via = mapper._should_selectin_load(
+ None, _polymorphic_from)
+ if selectin_load_via and selectin_load_via is not _polymorphic_from:
# only_load_props goes w/ refresh_state only, and in a refresh
# we are a single row query for the exact entity; polymorphic
# loading does not apply
assert only_load_props is None
- callable_ = _load_subclass_via_in(context, path, mapper)
+ callable_ = _load_subclass_via_in(context, path, selectin_load_via)
PostLoad.callable_for_path(
- context, load_path, mapper,
- callable_, mapper)
+ context, load_path, selectin_load_via,
+ callable_, selectin_load_via)
post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -523,12 +529,15 @@ def _instance_processor(
return _instance
-@util.dependencies("sqlalchemy.ext.baked")
-def _load_subclass_via_in(baked, context, path, mapper):
+def _load_subclass_via_in(context, path, entity):
+ mapper = entity.mapper
zero_idx = len(mapper.base_mapper.primary_key) == 1
- q, enable_opt, disable_opt = mapper._subclass_load_via_in
+ if entity.is_aliased_class:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+ else:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
def do_load(context, path, states, load_only, effective_entity):
orig_query = context.query
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index d102618a2..9b9457213 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -2706,11 +2706,44 @@ class Mapper(InspectionAttr):
cols.extend(props[key].columns)
return sql.select(cols, cond, use_labels=True)
- @_memoized_configured_property
+ def _iterate_to_target_viawpoly(self, mapper):
+ if self.isa(mapper):
+ prev = self
+ for m in self.iterate_to_root():
+ yield m
+
+ if m is not prev and prev not in \
+ m._with_polymorphic_mappers:
+ break
+
+ prev = m
+ if m is mapper:
+ break
+
+ def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
+ if not enabled_via_opt:
+ # common case, takes place for all polymorphic loads
+ mapper = polymorphic_from
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == 'selectin':
+ return m
+ else:
+ # uncommon case, selectin load options were used
+ enabled_via_opt = set(enabled_via_opt)
+ enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt}
+ for entity in enabled_via_opt.union([polymorphic_from]):
+ mapper = entity.mapper
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == 'selectin' or \
+ m in enabled_via_opt_mappers:
+ return enabled_via_opt_mappers.get(m, m)
+
+ return None
+
@util.dependencies(
"sqlalchemy.ext.baked",
"sqlalchemy.orm.strategy_options")
- def _subclass_load_via_in(self, baked, strategy_options):
+ def _subclass_load_via_in(self, baked, strategy_options, entity):
"""Assemble a BakedQuery that can load the columns local to
this subclass as a SELECT with IN.
@@ -2722,8 +2755,8 @@ class Mapper(InspectionAttr):
keep_props = set(
[polymorphic_prop] + self._identity_key_props)
- disable_opt = strategy_options.Load(self)
- enable_opt = strategy_options.Load(self)
+ disable_opt = strategy_options.Load(entity)
+ enable_opt = strategy_options.Load(entity)
for prop in self.attrs:
if prop.parent is self or prop in keep_props:
@@ -2747,11 +2780,22 @@ class Mapper(InspectionAttr):
else:
in_expr = self.primary_key[0]
- q = baked.BakedQuery(
- self._compiled_cache,
- lambda session: session.query(self),
- (self, )
- )
+ if entity.is_aliased_class:
+ assert entity.mapper is self
+ q = baked.BakedQuery(
+ self._compiled_cache,
+ lambda session: session.query(entity).
+ select_entity_from(entity.selectable)._adapt_all_clauses(),
+ (self, )
+ )
+ q.spoil()
+ else:
+ q = baked.BakedQuery(
+ self._compiled_cache,
+ lambda session: session.query(self),
+ (self, )
+ )
+
q += lambda q: q.filter(
in_expr.in_(
sql.bindparam('primary_keys', expanding=True)
@@ -2760,6 +2804,10 @@ class Mapper(InspectionAttr):
return q, enable_opt, disable_opt
+ @_memoized_configured_property
+ def _subclass_load_via_in_mapper(self):
+ return self._subclass_load_via_in(self)
+
def cascade_iterator(self, type_, state, halt_on=None):
"""Iterate each element and its mapper in an object graph,
for all relationships that meet the given cascade rule.
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 7c313e635..752f182e5 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1688,6 +1688,7 @@ class Session(_SessionClassMethods):
state.key = instance_key
self.identity_map.replace(state)
+ state._orphaned_outside_of_session = False
statelib.InstanceState._commit_all_states(
((state, state.dict) for state in states),
@@ -1762,6 +1763,7 @@ class Session(_SessionClassMethods):
self.add(instance, _warn=False)
def _save_or_update_state(self, state):
+ state._orphaned_outside_of_session = False
self._save_or_update_impl(state)
mapper = _state_mapper(state)
@@ -2271,11 +2273,17 @@ class Session(_SessionClassMethods):
proc = new.union(dirty).difference(deleted)
for state in proc:
- is_orphan = (
- _state_mapper(state)._is_orphan(state) and state.has_identity)
- _reg = flush_context.register_object(state, isdelete=is_orphan)
- assert _reg, "Failed to add object to the flush context!"
- processed.add(state)
+ is_orphan = _state_mapper(state)._is_orphan(state)
+
+ is_persistent_orphan = is_orphan and state.has_identity
+
+ if is_orphan and not is_persistent_orphan and state._orphaned_outside_of_session:
+ self._expunge_states([state])
+ else:
+ _reg = flush_context.register_object(
+ state, isdelete=is_persistent_orphan)
+ assert _reg, "Failed to add object to the flush context!"
+ processed.add(state)
# put all remaining deletes into the flush context.
if objset:
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
index 1781a41e9..2e53fe9e3 100644
--- a/lib/sqlalchemy/orm/state.py
+++ b/lib/sqlalchemy/orm/state.py
@@ -61,6 +61,7 @@ class InstanceState(interfaces.InspectionAttr):
expired = False
_deleted = False
_load_pending = False
+ _orphaned_outside_of_session = False
is_instance = True
callables = ()
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index 796f859f8..c47536a02 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -1414,7 +1414,7 @@ def selectin_polymorphic(loadopt, classes):
"""
loadopt.set_class_strategy(
{"selectinload_polymorphic": True},
- opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))}
+ opts={"entities": tuple(sorted((inspect(cls) for cls in classes), key=id))}
)
return loadopt
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index ee3e2043b..a3bd53637 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -52,22 +52,24 @@ def track_cascade_events(descriptor, prop):
return
sess = state.session
- if sess:
- prop = state.manager.mapper._props[key]
+ prop = state.manager.mapper._props[key]
- if sess._warn_on_events:
- sess._flush_warning(
- "collection remove"
- if prop.uselist
- else "related attribute delete")
+ if sess and sess._warn_on_events:
+ sess._flush_warning(
+ "collection remove"
+ if prop.uselist
+ else "related attribute delete")
- # expunge pending orphans
- item_state = attributes.instance_state(item)
- if prop._cascade.delete_orphan and \
- item_state in sess._new and \
- prop.mapper._is_orphan(item_state):
+ # expunge pending orphans
+ item_state = attributes.instance_state(item)
+
+ if prop._cascade.delete_orphan and \
+ prop.mapper._is_orphan(item_state):
+ if sess and item_state in sess._new:
sess.expunge(item)
+ else:
+ item_state._orphaned_outside_of_session = True
def set_(state, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index c0854ea55..08d0f0aac 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -520,13 +520,10 @@ class AssertsExecutionResults(object):
db, callable_, assertsql.CountStatements(count))
@contextlib.contextmanager
- def assert_execution(self, *rules):
- assertsql.asserter.add_rules(rules)
- try:
+ def assert_execution(self, db, *rules):
+ with self.sql_execution_asserter(db) as asserter:
yield
- assertsql.asserter.statement_complete()
- finally:
- assertsql.asserter.clear_rules()
+ asserter.assert_(*rules)
- def assert_statement_count(self, count):
- return self.assert_execution(assertsql.CountStatements(count))
+ def assert_statement_count(self, db, count):
+ return self.assert_execution(db, assertsql.CountStatements(count))