diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-06-20 19:28:29 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-06-20 19:28:29 -0400 |
| commit | 3dd536ac06808adcf9c10707dbf2ebb6e3842be7 (patch) | |
| tree | d102291da86021aa4584ef836dbb0471596c3eeb /lib | |
| parent | 40098941007ff3aa1593e834915c4042c1668dc2 (diff) | |
| download | sqlalchemy-3dd536ac06808adcf9c10707dbf2ebb6e3842be7.tar.gz | |
- [feature] The of_type() construct on attributes
now accepts aliased() class constructs as well
as with_polymorphic constructs, and works with
query.join(), any(), has(), and also
eager loaders subqueryload(), joinedload(),
contains_eager()
[ticket:2438] [ticket:1106]
- a rewrite of the query path system to use an
object based approach for more succinct usage. the system
has been designed carefully to not add an excessive method overhead.
- [feature] select() features a correlate_except()
method, auto correlates all selectables except those
passed. Is needed here for the updated any()/has()
functionality.
- remove some old cruft from LoaderStrategy, init(),debug_callable()
- use a namedtuple for _extended_entity_info. This method should
become standard within the orm internals
- some tweaks to the memory profile tests, number of runs can
be customized to work around pysqlite's very annoying behavior
- try to simplify PropertyOption._get_paths(), rename to _process_paths(),
returns a single list now. overall works more completely as was needed
for of_type() functionality
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 187 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 73 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 22 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 68 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/state.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 354 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/util.py | 204 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 32 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 12 |
12 files changed, 597 insertions, 379 deletions
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 7c5955c5d..1750bc9f8 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -112,7 +112,8 @@ __all__ = ( 'synonym', 'undefer', 'undefer_group', - 'validates' + 'validates', + 'with_polymorphic' ) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 55e0291b5..e71752ab5 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -103,12 +103,14 @@ class QueryableAttribute(interfaces.PropComparator): """Base class for class-bound attributes. """ def __init__(self, class_, key, impl=None, - comparator=None, parententity=None): + comparator=None, parententity=None, + of_type=None): self.class_ = class_ self.key = key self.impl = impl self.comparator = comparator self.parententity = parententity + self._of_type = of_type manager = manager_of_class(class_) # manager is None in the case of AliasedClass @@ -137,6 +139,15 @@ class QueryableAttribute(interfaces.PropComparator): def __clause_element__(self): return self.comparator.__clause_element__() + def of_type(self, cls): + return QueryableAttribute( + self.class_, + self.key, + self.impl, + self.comparator.of_type(cls), + self.parententity, + of_type=cls) + def label(self, name): return self.__clause_element__().label(name) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index bda48cbb1..8d185e9f3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -42,7 +42,6 @@ __all__ = ( 'SessionExtension', 'StrategizedOption', 'StrategizedProperty', - 'build_path', ) EXT_CONTINUE = util.symbol('EXT_CONTINUE') @@ -77,7 +76,7 @@ class MapperProperty(object): """ - def setup(self, context, entity, path, reduced_path, adapter, **kwargs): + def setup(self, context, entity, path, adapter, **kwargs): """Called by Query for the purposes of constructing a SQL statement. Each MapperProperty associated with the target mapper processes the @@ -87,7 +86,7 @@ class MapperProperty(object): pass - def create_row_processor(self, context, path, reduced_path, + def create_row_processor(self, context, path, mapper, row, adapter): """Return a 3-tuple consisting of three row processing functions. @@ -112,7 +111,7 @@ class MapperProperty(object): def set_parent(self, parent, init): self.parent = parent - def instrument_class(self, mapper): + def instrument_class(self, mapper): # pragma: no-coverage raise NotImplementedError() _compile_started = False @@ -308,15 +307,23 @@ class StrategizedProperty(MapperProperty): strategy_wildcard_key = None - def _get_context_strategy(self, context, reduced_path): - key = ('loaderstrategy', reduced_path) + @util.memoized_property + def _wildcard_path(self): + if self.strategy_wildcard_key: + return ('loaderstrategy', (self.strategy_wildcard_key,)) + else: + return None + + def _get_context_strategy(self, context, path): + # this is essentially performance inlining. + key = ('loaderstrategy', path.reduced_path + (self.key,)) cls = None if key in context.attributes: cls = context.attributes[key] - elif self.strategy_wildcard_key: - key = ('loaderstrategy', (self.strategy_wildcard_key,)) - if key in context.attributes: - cls = context.attributes[key] + else: + wc_key = self._wildcard_path + if wc_key and wc_key in context.attributes: + cls = context.attributes[wc_key] if cls: try: @@ -335,15 +342,15 @@ class StrategizedProperty(MapperProperty): self._strategies[cls] = strategy = cls(self) return strategy - def setup(self, context, entity, path, reduced_path, adapter, **kwargs): - self._get_context_strategy(context, reduced_path + (self.key,)).\ + def setup(self, context, entity, path, adapter, **kwargs): + self._get_context_strategy(context, path).\ setup_query(context, entity, path, - reduced_path, adapter, **kwargs) + adapter, **kwargs) - def create_row_processor(self, context, path, reduced_path, mapper, row, adapter): - return self._get_context_strategy(context, reduced_path + (self.key,)).\ + def create_row_processor(self, context, path, mapper, row, adapter): + return self._get_context_strategy(context, path).\ create_row_processor(context, path, - reduced_path, mapper, row, adapter) + mapper, row, adapter) def do_init(self): self._strategies = {} @@ -354,30 +361,6 @@ class StrategizedProperty(MapperProperty): not mapper.class_manager._attr_has_impl(self.key): self.strategy.init_class_attribute(mapper) -def build_path(entity, key, prev=None): - if prev: - return prev + (entity, key) - else: - return (entity, key) - -def serialize_path(path): - if path is None: - return None - - return zip( - [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], - [path[i] for i in range(1, len(path), 2)] + [None] - ) - -def deserialize_path(path): - if path is None: - return None - - p = tuple(chain(*[(mapperutil.class_mapper(cls), key) for cls, key in path])) - if p and p[-1] is None: - p = p[0:-1] - return p - class MapperOption(object): """Describe a modification to a Query.""" @@ -414,11 +397,11 @@ class PropertyOption(MapperOption): self._process(query, False) def _process(self, query, raiseerr): - paths, mappers = self._get_paths(query, raiseerr) + paths = self._process_paths(query, raiseerr) if paths: - self.process_query_property(query, paths, mappers) + self.process_query_property(query, paths) - def process_query_property(self, query, paths, mappers): + def process_query_property(self, query, paths): pass def __getstate__(self): @@ -450,8 +433,7 @@ class PropertyOption(MapperOption): searchfor = mapperutil._class_to_mapper(mapper) isa = True for ent in query._mapper_entities: - if searchfor is ent.path_entity or isa \ - and searchfor.common_parent(ent.path_entity): + if ent.corresponds_to(searchfor): return ent else: if raiseerr: @@ -488,15 +470,21 @@ class PropertyOption(MapperOption): else: return None - def _get_paths(self, query, raiseerr): - path = None + def _process_paths(self, query, raiseerr): + """reconcile the 'key' for this PropertyOption with + the current path and entities of the query. + + Return a list of affected paths. + + """ + path = mapperutil.PathRegistry.root entity = None - l = [] - mappers = [] + paths = [] + no_result = [] # _current_path implies we're in a # secondary load with an existing path - current_path = list(query._current_path) + current_path = list(query._current_path.path) tokens = deque(self.key) while tokens: @@ -504,7 +492,7 @@ class PropertyOption(MapperOption): if isinstance(token, basestring): # wildcard token if token.endswith(':*'): - return [(token,)], [] + return [path.token(token)] sub_tokens = token.split(".", 1) token = sub_tokens[0] tokens.extendleft(sub_tokens[1:]) @@ -516,7 +504,7 @@ class PropertyOption(MapperOption): current_path = current_path[2:] continue else: - return [], [] + return no_result if not entity: entity = self._find_entity_basestring( @@ -524,10 +512,10 @@ class PropertyOption(MapperOption): token, raiseerr) if entity is None: - return [], [] - path_element = entity.path_entity + return no_result + path_element = entity.entity_zero mapper = entity.mapper - mappers.append(mapper) + if hasattr(mapper.class_, token): prop = getattr(mapper.class_, token).property else: @@ -538,7 +526,7 @@ class PropertyOption(MapperOption): token, mapper) ) else: - return [], [] + return no_result elif isinstance(token, PropComparator): prop = token.property @@ -550,7 +538,7 @@ class PropertyOption(MapperOption): current_path = current_path[2:] continue else: - return [], [] + return no_result if not entity: entity = self._find_entity_prop_comparator( @@ -559,10 +547,9 @@ class PropertyOption(MapperOption): token.parententity, raiseerr) if not entity: - return [], [] - path_element = entity.path_entity + return no_result + path_element = entity.entity_zero mapper = entity.mapper - mappers.append(prop.parent) else: raise sa_exc.ArgumentError( "mapper option expects " @@ -572,11 +559,20 @@ class PropertyOption(MapperOption): raise sa_exc.ArgumentError("Attribute '%s' does not " "link from element '%s'" % (token, path_element)) - path = build_path(path_element, prop.key, path) + path = path[path_element][prop.key] + + paths.append(path) - l.append(path) if getattr(token, '_of_type', None): - path_element = mapper = token._of_type + ac = token._of_type + ext_info = mapperutil._extended_entity_info(ac) + path_element = mapper = ext_info.mapper + if not ext_info.is_aliased_class: + ac = mapperutil.with_polymorphic( + ext_info.mapper.base_mapper, + ext_info.mapper, aliased=True) + ext_info = mapperutil._extended_entity_info(ac) + path.set(query, "path_with_polymorphic", ext_info) else: path_element = mapper = getattr(prop, 'mapper', None) if mapper is None and tokens: @@ -590,9 +586,9 @@ class PropertyOption(MapperOption): # ran out of tokens before # current_path was exhausted. assert not tokens - return [], [] + return no_result - return l, mappers + return paths class StrategizedOption(PropertyOption): """A MapperOption that affects which LoaderStrategy will be used @@ -601,40 +597,25 @@ class StrategizedOption(PropertyOption): chained = False - def process_query_property(self, query, paths, mappers): - - # _get_context_strategy may receive the path in terms of a base - # mapper - e.g. options(eagerload_all(Company.employees, - # Engineer.machines)) in the polymorphic tests leads to - # "(Person, 'machines')" in the path due to the mechanics of how - # the eager strategy builds up the path - + def process_query_property(self, query, paths): + strategy = self.get_strategy_class() if self.chained: for path in paths: - query._attributes[('loaderstrategy', - _reduce_path(path))] = \ - self.get_strategy_class() + path.set( + query, + "loaderstrategy", + strategy + ) else: - query._attributes[('loaderstrategy', - _reduce_path(paths[-1]))] = \ - self.get_strategy_class() + paths[-1].set( + query, + "loaderstrategy", + strategy + ) def get_strategy_class(self): raise NotImplementedError() -def _reduce_path(path): - """Convert a (mapper, path) path to use base mappers. - - This is used to allow more open ended selection of loader strategies, i.e. - Mapper -> prop1 -> Subclass -> prop2, where Subclass is a sub-mapper - of the mapper referenced by Mapper.prop1. - - """ - return tuple([i % 2 != 0 and - element or - getattr(element, 'base_mapper', element) - for i, element in enumerate(path)]) - class LoaderStrategy(object): """Describe the loading behavior of a StrategizedProperty object. @@ -663,22 +644,14 @@ class LoaderStrategy(object): self.is_class_level = False self.parent = self.parent_property.parent self.key = self.parent_property.key - # TODO: there's no particular reason we need - # the separate .init() method at this point. - # It's possible someone has written their - # own LS object. - self.init() - - def init(self): - raise NotImplementedError("LoaderStrategy") def init_class_attribute(self, mapper): pass - def setup_query(self, context, entity, path, reduced_path, adapter, **kwargs): + def setup_query(self, context, entity, path, adapter, **kwargs): pass - def create_row_processor(self, context, path, reduced_path, mapper, + def create_row_processor(self, context, path, mapper, row, adapter): """Return row processing functions which fulfill the contract specified by MapperProperty.create_row_processor. @@ -691,16 +664,6 @@ class LoaderStrategy(object): def __str__(self): return str(self.parent_property) - def debug_callable(self, fn, logger, announcement, logfn): - if announcement: - logger.debug(announcement) - if logfn: - def call(*args, **kwargs): - logger.debug(logfn(*args, **kwargs)) - return fn(*args, **kwargs) - return call - else: - return fn class InstrumentationManager(object): """User-defined class instrumentation extension. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 2ec30f0ba..789f29c73 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -28,7 +28,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, \ PropComparator from sqlalchemy.orm.util import _INSTRUMENTOR, _class_to_mapper, \ - _state_mapper, class_mapper, instance_str, state_str + _state_mapper, class_mapper, instance_str, state_str,\ + PathRegistry import sys sessionlib = util.importlater("sqlalchemy.orm", "session") @@ -432,6 +433,10 @@ class Mapper(object): dispatch = event.dispatcher(events.MapperEvents) + @util.memoized_property + def _sa_path_registry(self): + return PathRegistry.per_mapper(self) + def _configure_inheritance(self): """Configure settings related to inherting and/or inherited mappers being present.""" @@ -1302,13 +1307,12 @@ class Mapper(object): return mappers - def _selectable_from_mappers(self, mappers): + def _selectable_from_mappers(self, mappers, innerjoin): """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), construct an outerjoin amongst those mapper's mapped tables. """ - from_obj = self.mapped_table for m in mappers: if m is self: @@ -1318,7 +1322,11 @@ class Mapper(object): "'with_polymorphic()' requires 'selectable' argument " "when concrete-inheriting mappers are used.") elif not m.single: - from_obj = from_obj.outerjoin(m.local_table, + if innerjoin: + from_obj = from_obj.join(m.local_table, + m.inherit_condition) + else: + from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) return from_obj @@ -1350,9 +1358,11 @@ class Mapper(object): return selectable else: return self._selectable_from_mappers( - self._mappers_from_spec(spec, selectable)) + self._mappers_from_spec(spec, selectable), + False) - def _with_polymorphic_args(self, spec=None, selectable=False): + def _with_polymorphic_args(self, spec=None, selectable=False, + innerjoin=False): if self.with_polymorphic: if not spec: spec = self.with_polymorphic[0] @@ -1364,7 +1374,8 @@ class Mapper(object): if selectable is not None: return mappers, selectable else: - return mappers, self._selectable_from_mappers(mappers) + return mappers, self._selectable_from_mappers(mappers, + innerjoin) @_memoized_configured_property def _polymorphic_properties(self): @@ -1926,7 +1937,7 @@ class Mapper(object): return result - def _instance_processor(self, context, path, reduced_path, adapter, + def _instance_processor(self, context, path, adapter, polymorphic_from=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): @@ -1951,7 +1962,7 @@ class Mapper(object): polymorphic_on = self.polymorphic_on polymorphic_instances = util.PopulateDict( self._configure_subclass_mapper( - context, path, reduced_path, adapter) + context, path, adapter) ) version_id_col = self.version_id_col @@ -1968,7 +1979,9 @@ class Mapper(object): new_populators = [] existing_populators = [] eager_populators = [] - load_path = context.query._current_path + path + load_path = context.query._current_path + path \ + if context.query._current_path.path \ + else path def populate_state(state, dict_, row, isnew, only_load_props): if isnew: @@ -1978,7 +1991,7 @@ class Mapper(object): state.load_path = load_path if not new_populators: - self._populators(context, path, reduced_path, row, adapter, + self._populators(context, path, row, adapter, new_populators, existing_populators, eager_populators @@ -2015,7 +2028,7 @@ class Mapper(object): def _instance(row, result): if not new_populators and invoke_all_eagers: - self._populators(context, path, reduced_path, row, adapter, + self._populators(context, path, row, adapter, new_populators, existing_populators, eager_populators @@ -2191,16 +2204,17 @@ class Mapper(object): return instance return _instance - def _populators(self, context, path, reduced_path, row, adapter, + def _populators(self, context, path, row, adapter, new_populators, existing_populators, eager_populators): - """Produce a collection of attribute level row processor callables.""" + """Produce a collection of attribute level row processor + callables.""" delayed_populators = [] - pops = (new_populators, existing_populators, delayed_populators, eager_populators) + pops = (new_populators, existing_populators, delayed_populators, + eager_populators) for prop in self._props.itervalues(): for i, pop in enumerate(prop.create_row_processor( - context, path, - reduced_path, + context, path, self, row, adapter)): if pop is not None: pops[i].append((prop.key, pop)) @@ -2208,7 +2222,7 @@ class Mapper(object): if delayed_populators: new_populators.extend(delayed_populators) - def _configure_subclass_mapper(self, context, path, reduced_path, adapter): + def _configure_subclass_mapper(self, context, path, adapter): """Produce a mapper level row processor callable factory for mappers inheriting this one.""" @@ -2223,18 +2237,17 @@ class Mapper(object): return None # replace the tip of the path info with the subclass mapper - # being used. that way accurate "load_path" info is available - # for options invoked during deferred loads. - # we lose AliasedClass path elements this way, but currently, - # those are not needed at this stage. - - # this asserts to true - #assert mapper.isa(_class_to_mapper(path[-1])) - - return mapper._instance_processor(context, path[0:-1] + (mapper,), - reduced_path[0:-1] + (mapper.base_mapper,), - adapter, - polymorphic_from=self) + # being used, that way accurate "load_path" info is available + # for options invoked during deferred loads, e.g. + # query(Person).options(defer(Engineer.machines, Machine.name)). + # for AliasedClass paths, disregard this step (new in 0.8). + return mapper._instance_processor( + context, + path.parent[mapper] + if not path.is_aliased_class + else path, + adapter, + polymorphic_from=self) return configure_subclass_mapper log.class_logger(Mapper) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9a2d05754..5634a9c5f 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -18,7 +18,8 @@ from sqlalchemy.sql import operators, expression, visitors from sqlalchemy.orm import attributes, dependency, mapper, \ object_mapper, strategies, configure_mappers, relationships from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \ - _orm_annotate, _orm_deannotate, _orm_full_deannotate + _orm_annotate, _orm_deannotate, _orm_full_deannotate,\ + _entity_info from sqlalchemy.orm.interfaces import MANYTOMANY, MANYTOONE, \ MapperProperty, ONETOMANY, PropComparator, StrategizedProperty @@ -305,7 +306,7 @@ class RelationshipProperty(StrategizedProperty): self.mapper = mapper self.adapter = adapter if of_type: - self._of_type = _class_to_mapper(of_type) + self._of_type = of_type def adapted(self, adapter): """Return a copy of this PropComparator which will use the @@ -318,7 +319,7 @@ class RelationshipProperty(StrategizedProperty): getattr(self, '_of_type', None), adapter) - @property + @util.memoized_property def parententity(self): return self.property.parent @@ -406,9 +407,8 @@ class RelationshipProperty(StrategizedProperty): def _criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): - target_mapper = self._of_type - to_selectable = target_mapper._with_polymorphic_selectable - if self.property._is_self_referential: + target_mapper, to_selectable, is_aliased_class = _entity_info(self._of_type) + if self.property._is_self_referential and not is_aliased_class: to_selectable = to_selectable.alias() single_crit = target_mapper._single_table_criterion @@ -418,6 +418,7 @@ class RelationshipProperty(StrategizedProperty): else: criterion = single_crit else: + is_aliased_class = False to_selectable = None if self.adapter: @@ -445,8 +446,7 @@ class RelationshipProperty(StrategizedProperty): else: j = _orm_annotate(pj, exclude=self.property.remote_side) - # MARKMARK - if criterion is not None and target_adapter: + if criterion is not None and target_adapter and not is_aliased_class: # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) @@ -460,8 +460,10 @@ class RelationshipProperty(StrategizedProperty): crit = j & criterion - return sql.exists([1], crit, from_obj=dest).\ - correlate(source._annotate({'_orm_adapt':True})) + ex = sql.exists([1], crit, from_obj=dest).correlate_except(dest) + if secondary is not None: + ex = ex.correlate_except(secondary) + return ex def any(self, criterion=None, **kwargs): """Produce an expression that tests a collection against diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 2c06063bc..987a77ba9 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -32,7 +32,7 @@ from sqlalchemy.orm import ( ) from sqlalchemy.orm.util import ( AliasedClass, ORMAdapter, _entity_descriptor, _entity_info, - _extended_entity_info, + _extended_entity_info, PathRegistry, _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable, join as orm_join,with_parent, _attr_as_key, aliased ) @@ -53,6 +53,8 @@ def _generative(*assertions): return self return generate +_path_registry = PathRegistry.root + class Query(object): """ORM-level SQL construction object. @@ -88,7 +90,6 @@ class Query(object): _invoke_all_eagers = True _version_check = False _autoflush = True - _current_path = () _only_load_props = None _refresh_state = None _from_obj = () @@ -105,6 +106,8 @@ class Query(object): _with_hints = () _enable_single_crit = True + _current_path = _path_registry + def __init__(self, entities, session=None): self.session = session self._polymorphic_adapters = {} @@ -125,28 +128,25 @@ class Query(object): for ent in entities: for entity in ent.entities: if entity not in d: - mapper, selectable, \ - is_aliased_class, with_polymorphic_mappers, \ - with_polymorphic_discriminator = \ - _extended_entity_info(entity) - if not is_aliased_class and mapper.with_polymorphic: - if mapper.mapped_table not in \ + ext_info = _extended_entity_info(entity) + if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic: + if ext_info.mapper.mapped_table not in \ self._polymorphic_adapters: - self._mapper_loads_polymorphically_with(mapper, + self._mapper_loads_polymorphically_with(ext_info.mapper, sql_util.ColumnAdapter( - selectable, - mapper._equivalent_columns)) + ext_info.selectable, + ext_info.mapper._equivalent_columns)) aliased_adapter = None - elif is_aliased_class: + elif ext_info.is_aliased_class: aliased_adapter = sql_util.ColumnAdapter( - selectable, - mapper._equivalent_columns) + ext_info.selectable, + ext_info.mapper._equivalent_columns) else: aliased_adapter = None - d[entity] = (mapper, aliased_adapter, selectable, - is_aliased_class, with_polymorphic_mappers, - with_polymorphic_discriminator) + d[entity] = (ext_info.mapper, aliased_adapter, ext_info.selectable, + ext_info.is_aliased_class, ext_info.with_polymorphic_mappers, + ext_info.with_polymorphic_discriminator) ent.setup_entity(entity, *d[entity]) def _mapper_loads_polymorphically_with(self, mapper, adapter): @@ -1251,7 +1251,7 @@ class Query(object): def having(self, criterion): """apply a HAVING criterion to the query and return the newly resulting :class:`.Query`. - + :meth:`having` is used in conjunction with :meth:`group_by`. HAVING criterion makes it possible to use filters on aggregate @@ -1940,7 +1940,6 @@ class Query(object): raise sa_exc.InvalidRequestError( "Could not find a FROM clause to join from. " "Tried joining to %s, but got: %s" % (right, ae)) - self._from_obj = self._from_obj + (clause,) def _reset_joinpoint(self): @@ -2872,7 +2871,7 @@ class _MapperEntity(_QueryEntity): query._entities.append(self) self.entities = [entity] - self.entity_zero = self.expr = entity + self.expr = entity def setup_entity(self, entity, mapper, aliased_adapter, from_obj, is_aliased_class, @@ -2885,16 +2884,12 @@ class _MapperEntity(_QueryEntity): self._with_polymorphic = with_polymorphic self._polymorphic_discriminator = with_polymorphic_discriminator if is_aliased_class: - self.path_entity = self.entity_zero = entity - self._path = (entity,) + self.entity_zero = entity self._label_name = self.entity_zero._sa_label_name - self._reduced_path = (self.path_entity, ) else: - self.path_entity = mapper - self._path = (mapper,) - self._reduced_path = (mapper.base_mapper, ) self.entity_zero = mapper self._label_name = self.mapper.class_.__name__ + self.path = self.entity_zero._sa_path_registry def set_with_polymorphic(self, query, cls_or_mappers, selectable, polymorphic_on): @@ -2929,10 +2924,13 @@ class _MapperEntity(_QueryEntity): return self.entity_zero def corresponds_to(self, entity): - if _is_aliased_class(entity) or self.is_aliased_class: - return entity is self.path_entity + entity_info = _extended_entity_info(entity) + if entity_info.is_aliased_class or self.is_aliased_class: + return entity is self.entity_zero \ + or \ + entity in self._with_polymorphic else: - return entity.common_parent(self.path_entity) + return entity.common_parent(self.entity_zero) def adapt_to_selectable(self, query, sel): query._entities.append(self) @@ -2976,8 +2974,7 @@ class _MapperEntity(_QueryEntity): if self.primary_entity: _instance = self.mapper._instance_processor( context, - self._path, - self._reduced_path, + self.path, adapter, only_load_props=query._only_load_props, refresh_state=context.refresh_state, @@ -2987,8 +2984,7 @@ class _MapperEntity(_QueryEntity): else: _instance = self.mapper._instance_processor( context, - self._path, - self._reduced_path, + self.path, adapter, polymorphic_discriminator= self._polymorphic_discriminator) @@ -3024,8 +3020,7 @@ class _MapperEntity(_QueryEntity): value.setup( context, self, - self._path, - self._reduced_path, + self.path, adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns @@ -3211,7 +3206,8 @@ class QueryContext(object): self.create_eager_joins = [] self.propagate_options = set(o for o in query._with_options if o.propagate_to_loaders) - self.attributes = query._attributes.copy() + self.attributes = self._attributes = query._attributes.copy() + class AliasOption(interfaces.MapperOption): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 9b0f7538f..720554483 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -24,6 +24,7 @@ from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, \ mapperlib = util.importlater("sqlalchemy.orm", "mapperlib") sessionlib = util.importlater("sqlalchemy.orm", "session") + class InstanceState(object): """tracks state information at the instance level.""" @@ -177,7 +178,7 @@ class InstanceState(object): ) if k in self.__dict__ ) if self.load_path: - d['load_path'] = interfaces.serialize_path(self.load_path) + d['load_path'] = self.load_path.serialize() self.manager.dispatch.pickle(self, d) @@ -222,7 +223,8 @@ class InstanceState(object): ]) if 'load_path' in state: - self.load_path = interfaces.deserialize_path(state['load_path']) + self.load_path = orm_util.PathRegistry.\ + deserialize(state['load_path']) # setup _sa_instance_state ahead of time so that # unpickle events can access the object normally. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d0f8962be..131ced0c9 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -73,7 +73,8 @@ def _register_attribute(strategy, mapper, useobject, compare_function=compare_function, useobject=useobject, extension=attribute_ext, - trackparent=useobject and (prop.single_parent or prop.direction is interfaces.ONETOMANY), + trackparent=useobject and (prop.single_parent + or prop.direction is interfaces.ONETOMANY), typecallable=typecallable, callable_=callable_, active_history=active_history, @@ -92,27 +93,29 @@ class UninstrumentedColumnLoader(LoaderStrategy): if the argument is against the with_polymorphic selectable. """ - def init(self): + def __init__(self, parent): + super(UninstrumentedColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns - def setup_query(self, context, entity, path, reduced_path, adapter, + def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs): for c in self.columns: if adapter: c = adapter.columns[c] column_collection.append(c) - def create_row_processor(self, context, path, reduced_path, mapper, row, adapter): + def create_row_processor(self, context, path, mapper, row, adapter): return None, None, None class ColumnLoader(LoaderStrategy): """Provide loading behavior for a :class:`.ColumnProperty`.""" - def init(self): + def __init__(self, parent): + super(ColumnLoader, self).__init__(parent) self.columns = self.parent_property.columns self.is_composite = hasattr(self.parent_property, 'composite_class') - def setup_query(self, context, entity, path, reduced_path, + def setup_query(self, context, entity, path, adapter, column_collection, **kwargs): for c in self.columns: if adapter: @@ -131,7 +134,7 @@ class ColumnLoader(LoaderStrategy): active_history = active_history ) - def create_row_processor(self, context, path, reduced_path, + def create_row_processor(self, context, path, mapper, row, adapter): key = self.key # look through list of columns represented here @@ -153,7 +156,15 @@ log.class_logger(ColumnLoader) class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" - def create_row_processor(self, context, path, reduced_path, mapper, row, adapter): + def __init__(self, parent): + super(DeferredColumnLoader, self).__init__(parent) + if hasattr(self.parent_property, 'composite_class'): + raise NotImplementedError("Deferred loading for composite " + "types not implemented yet") + self.columns = self.parent_property.columns + self.group = self.parent_property.group + + def create_row_processor(self, context, path, mapper, row, adapter): col = self.columns[0] if adapter: col = adapter.columns[col] @@ -162,7 +173,7 @@ class DeferredColumnLoader(LoaderStrategy): if col in row: return self.parent_property._get_strategy(ColumnLoader).\ create_row_processor( - context, path, reduced_path, mapper, row, adapter) + context, path, mapper, row, adapter) elif not self.is_class_level: def set_deferred_for_local_state(state, dict_, row): @@ -175,13 +186,6 @@ class DeferredColumnLoader(LoaderStrategy): state.reset(dict_, key) return reset_col_for_deferred, None, None - def init(self): - if hasattr(self.parent_property, 'composite_class'): - raise NotImplementedError("Deferred loading for composite " - "types not implemented yet") - self.columns = self.parent_property.columns - self.group = self.parent_property.group - def init_class_attribute(self, mapper): self.is_class_level = True @@ -191,7 +195,7 @@ class DeferredColumnLoader(LoaderStrategy): expire_missing=False ) - def setup_query(self, context, entity, path, reduced_path, adapter, + def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs): if ( self.group is not None and @@ -199,7 +203,7 @@ class DeferredColumnLoader(LoaderStrategy): ) or (only_load_props and self.key in only_load_props): self.parent_property._get_strategy(ColumnLoader).\ setup_query(context, entity, - path, reduced_path, adapter, **kwargs) + path, adapter, **kwargs) def _load_for_state(self, state, passive): if not state.key: @@ -276,12 +280,13 @@ class UndeferGroupOption(MapperOption): self.group = group def process_query(self, query): - query._attributes[('undefer', self.group)] = True + query._attributes[("undefer", self.group)] = True class AbstractRelationshipLoader(LoaderStrategy): """LoaderStratgies which deal with related objects.""" - def init(self): + def __init__(self, parent): + super(AbstractRelationshipLoader, self).__init__(parent) self.mapper = self.parent_property.mapper self.target = self.parent_property.target self.uselist = self.parent_property.uselist @@ -301,7 +306,7 @@ class NoLoader(AbstractRelationshipLoader): typecallable = self.parent_property.collection_class, ) - def create_row_processor(self, context, path, reduced_path, mapper, row, adapter): + def create_row_processor(self, context, path, mapper, row, adapter): def invoke_no_load(state, dict_, row): state.initialize(self.key) return invoke_no_load, None, None @@ -314,8 +319,8 @@ class LazyLoader(AbstractRelationshipLoader): """ - def init(self): - super(LazyLoader, self).init() + def __init__(self, parent): + super(LazyLoader, self).__init__(parent) join_condition = self.parent_property._join_condition self._lazywhere, \ self._bind_to_col, \ @@ -533,7 +538,7 @@ class LazyLoader(AbstractRelationshipLoader): q = q.autoflush(False) if state.load_path: - q = q._with_current_path(state.load_path + (self.key,)) + q = q._with_current_path(state.load_path[self.key]) if state.load_options: q = q._conditional_options(*state.load_options) @@ -578,7 +583,7 @@ class LazyLoader(AbstractRelationshipLoader): return None - def create_row_processor(self, context, path, reduced_path, + def create_row_processor(self, context, path, mapper, row, adapter): key = self.key if not self.is_class_level: @@ -633,11 +638,11 @@ class ImmediateLoader(AbstractRelationshipLoader): init_class_attribute(mapper) def setup_query(self, context, entity, - path, reduced_path, adapter, column_collection=None, + path, adapter, column_collection=None, parentmapper=None, **kwargs): pass - def create_row_processor(self, context, path, reduced_path, + def create_row_processor(self, context, path, mapper, row, adapter): def load_immediate(state, dict_, row): state.get_impl(self.key).get(state, dict_) @@ -645,8 +650,8 @@ class ImmediateLoader(AbstractRelationshipLoader): return None, None, load_immediate class SubqueryLoader(AbstractRelationshipLoader): - def init(self): - super(SubqueryLoader, self).init() + def __init__(self, parent): + super(SubqueryLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth def init_class_attribute(self, mapper): @@ -655,31 +660,36 @@ class SubqueryLoader(AbstractRelationshipLoader): init_class_attribute(mapper) def setup_query(self, context, entity, - path, reduced_path, adapter, + path, adapter, column_collection=None, parentmapper=None, **kwargs): if not context.query._enable_eagerloads: return - path = path + (self.key, ) - reduced_path = reduced_path + (self.key, ) + path = path[self.key] # build up a path indicating the path from the leftmost # entity to the thing we're subquery loading. - subq_path = context.attributes.get(('subquery_path', None), ()) + with_poly_info = path.get(context, "path_with_polymorphic", None) + if with_poly_info is not None: + effective_entity = with_poly_info.entity + else: + effective_entity = self.mapper + + subq_path = context.attributes.get(('subquery_path', None), + mapperutil.PathRegistry.root) subq_path = subq_path + path - # join-depth / recursion check - if ("loaderstrategy", reduced_path) not in context.attributes: + # if not via query option, check for + # a cycle + if not path.contains(context, "loaderstrategy"): if self.join_depth: - if len(path) / 2 > self.join_depth: - return - else: - if self.mapper.base_mapper in \ - interfaces._reduce_path(subq_path): + if path.length / 2 > self.join_depth: return + elif subq_path.contains_mapper(self.mapper): + return subq_mapper, leftmost_mapper, leftmost_attr = \ self._get_leftmost(subq_path) @@ -692,7 +702,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # produce a subquery from it. left_alias = self._generate_from_original_query( orig_query, leftmost_mapper, - leftmost_attr, subq_path + leftmost_attr ) # generate another Query that will join the @@ -700,7 +710,7 @@ class SubqueryLoader(AbstractRelationshipLoader): # basically doing a longhand # "from_self()". (from_self() itself not quite industrial # strength enough for all contingencies...but very close) - q = orig_query.session.query(self.mapper) + q = orig_query.session.query(effective_entity) q._attributes = { ("orig_query", SubqueryLoader): orig_query, ('subquery_path', None) : subq_path @@ -712,16 +722,18 @@ class SubqueryLoader(AbstractRelationshipLoader): q = q.order_by(*local_attr) q = q.add_columns(*local_attr) - q = self._apply_joins(q, to_join, left_alias, parent_alias) + q = self._apply_joins(q, to_join, left_alias, + parent_alias, effective_entity) - q = self._setup_options(q, subq_path, orig_query) + q = self._setup_options(q, subq_path, orig_query, effective_entity) q = self._setup_outermost_orderby(q) # add new query to attributes to be picked up # by create_row_processor - context.attributes[('subquery', reduced_path)] = q + path.set(context, "subquery", q) def _get_leftmost(self, subq_path): + subq_path = subq_path.path subq_mapper = mapperutil._class_to_mapper(subq_path[0]) # determine attributes of the leftmost mapper @@ -743,7 +755,7 @@ class SubqueryLoader(AbstractRelationshipLoader): def _generate_from_original_query(self, orig_query, leftmost_mapper, - leftmost_attr, subq_path + leftmost_attr ): # reformat the original query # to look only for significant columns @@ -769,6 +781,8 @@ class SubqueryLoader(AbstractRelationshipLoader): def _prep_for_joins(self, left_alias, subq_path): + subq_path = subq_path.path + # figure out what's being joined. a.k.a. the fun part to_join = [ (subq_path[i], subq_path[i+1]) @@ -778,11 +792,14 @@ class SubqueryLoader(AbstractRelationshipLoader): # determine the immediate parent class we are joining from, # which needs to be aliased. + if len(to_join) > 1: + ext = mapperutil._extended_entity_info(subq_path[-2]) + if len(to_join) < 2: # in the case of a one level eager load, this is the # leftmost "left_alias". parent_alias = left_alias - elif subq_path[-2].isa(self.parent): + elif ext.mapper.isa(self.parent): # In the case of multiple levels, retrieve # it from subq_path[-2]. This is the same as self.parent # in the vast majority of cases, and [ticket:2014] @@ -800,10 +817,10 @@ class SubqueryLoader(AbstractRelationshipLoader): getattr(parent_alias, self.parent._columntoproperty[c].key) for c in local_cols ] - return to_join, local_attr, parent_alias - def _apply_joins(self, q, to_join, left_alias, parent_alias): + def _apply_joins(self, q, to_join, left_alias, parent_alias, + effective_entity): for i, (mapper, key) in enumerate(to_join): # we need to use query.join() as opposed to @@ -816,11 +833,18 @@ class SubqueryLoader(AbstractRelationshipLoader): first = i == 0 middle = i < len(to_join) - 1 second_to_last = i == len(to_join) - 2 + last = i == len(to_join) - 1 if first: attr = getattr(left_alias, key) + if last and effective_entity is not self.mapper: + attr = attr.of_type(effective_entity) else: - attr = key + if last and effective_entity is not self.mapper: + attr = getattr(parent_alias, key).\ + of_type(effective_entity) + else: + attr = key if second_to_last: q = q.join(parent_alias, attr, from_joinpoint=True) @@ -828,13 +852,14 @@ class SubqueryLoader(AbstractRelationshipLoader): q = q.join(attr, aliased=middle, from_joinpoint=True) return q - def _setup_options(self, q, subq_path, orig_query): + def _setup_options(self, q, subq_path, orig_query, effective_entity): # propagate loader options etc. to the new query. # these will fire relative to subq_path. q = q._with_current_path(subq_path) q = q._conditional_options(*orig_query._with_options) if orig_query._populate_existing: q._populate_existing = orig_query._populate_existing + return q def _setup_outermost_orderby(self, q): @@ -855,7 +880,7 @@ class SubqueryLoader(AbstractRelationshipLoader): q = q.order_by(*eager_order_by) return q - def create_row_processor(self, context, path, reduced_path, + def create_row_processor(self, context, path, mapper, row, adapter): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( @@ -863,27 +888,26 @@ class SubqueryLoader(AbstractRelationshipLoader): "population - eager loading cannot be applied." % self) - reduced_path = reduced_path + (self.key,) + path = path[self.key] - if ('subquery', reduced_path) not in context.attributes: + subq = path.get(context, 'subquery') + if subq is None: return None, None, None local_cols = self.parent_property.local_columns - q = context.attributes[('subquery', reduced_path)] - # cache the loaded collections in the context # so that inheriting mappers don't re-load when they # call upon create_row_processor again - if ('collections', reduced_path) in context.attributes: - collections = context.attributes[('collections', reduced_path)] - else: - collections = context.attributes[('collections', reduced_path)] = dict( + collections = path.get(context, "collections") + if collections is None: + collections = dict( (k, [v[0] for v in v]) for k, v in itertools.groupby( - q, + subq, lambda x:x[1:] )) + path.set(context, 'collections', collections) if adapter: local_cols = [adapter.columns[c] for c in local_cols] @@ -929,98 +953,114 @@ class JoinedLoader(AbstractRelationshipLoader): using joined eager loading. """ - def init(self): - super(JoinedLoader, self).init() + def __init__(self, parent): + super(JoinedLoader, self).__init__(parent) self.join_depth = self.parent_property.join_depth def init_class_attribute(self, mapper): self.parent_property.\ _get_strategy(LazyLoader).init_class_attribute(mapper) - def setup_query(self, context, entity, path, reduced_path, adapter, \ + def setup_query(self, context, entity, path, adapter, \ column_collection=None, parentmapper=None, allow_innerjoin=True, **kwargs): """Add a left outer join to the statement thats being constructed.""" - if not context.query._enable_eagerloads: return - path = path + (self.key,) - reduced_path = reduced_path + (self.key,) + path = path[self.key] + + with_polymorphic = None - if ("user_defined_eager_row_processor", reduced_path) in\ - context.attributes: + user_defined_adapter = path.get(context, + "user_defined_eager_row_processor", + False) + if user_defined_adapter is not False: clauses, adapter, add_to_collection = \ self._get_user_defined_adapter( - context, entity, reduced_path, adapter + context, entity, path, adapter, + user_defined_adapter ) else: - # check for join_depth or basic recursion, - # if the current path was not explicitly stated as - # a desired "loaderstrategy" (i.e. via query.options()) - if ("loaderstrategy", reduced_path) not in context.attributes: + # if not via query option, check for + # a cycle + if not path.contains(context, "loaderstrategy"): if self.join_depth: - if len(path) / 2 > self.join_depth: - return - else: - if self.mapper.base_mapper in reduced_path: + if path.length / 2 > self.join_depth: return + elif path.contains_mapper(self.mapper): + return clauses, adapter, add_to_collection, \ allow_innerjoin = self._generate_row_adapter( - context, entity, path, reduced_path, adapter, + context, entity, path, adapter, column_collection, parentmapper, allow_innerjoin ) - path += (self.mapper,) - reduced_path += (self.mapper.base_mapper,) + with_poly_info = path.get( + context, + "path_with_polymorphic", + None + ) + if with_poly_info is not None: + with_polymorphic = with_poly_info.with_polymorphic_mappers + else: + with_polymorphic = None - for value in self.mapper._polymorphic_properties: + path = path[self.mapper] + for value in self.mapper._iterate_polymorphic_properties( + mappers=with_polymorphic): value.setup( context, entity, path, - reduced_path, clauses, parentmapper=self.mapper, column_collection=add_to_collection, allow_innerjoin=allow_innerjoin) def _get_user_defined_adapter(self, context, entity, - reduced_path, adapter): - clauses = context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] + path, adapter, user_defined_adapter): adapter = entity._get_entity_clauses(context.query, context) - if adapter and clauses: - context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] = clauses = clauses.wrap(adapter) + if adapter and user_defined_adapter: + user_defined_adapter = user_defined_adapter.wrap(adapter) + path.set(context, "user_defined_eager_row_processor", + user_defined_adapter) elif adapter: - context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] = clauses = adapter + user_defined_adapter = adapter + path.set(context, "user_defined_eager_row_processor", + user_defined_adapter) add_to_collection = context.primary_columns - return clauses, adapter, add_to_collection + return user_defined_adapter, adapter, add_to_collection def _generate_row_adapter(self, - context, entity, path, reduced_path, adapter, + context, entity, path, adapter, column_collection, parentmapper, allow_innerjoin ): + with_poly_info = path.get( + context, + "path_with_polymorphic", + None + ) + if with_poly_info: + to_adapt = with_poly_info.entity + else: + to_adapt = mapperutil.AliasedClass(self.mapper) clauses = mapperutil.ORMAdapter( - mapperutil.AliasedClass(self.mapper), + to_adapt, equivalents=self.mapper._equivalent_columns, adapt_required=True) + assert clauses.aliased_class is not None if self.parent_property.direction != interfaces.MANYTOONE: context.multi_row_eager_loaders = True - innerjoin = allow_innerjoin and context.attributes.get( - ("eager_join_type", path), + innerjoin = allow_innerjoin and path.get(context, + "eager_join_type", self.parent_property.innerjoin) if not innerjoin: # if this is an outer join, all eager joins from @@ -1034,9 +1074,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) add_to_collection = context.secondary_columns - context.attributes[ - ("eager_row_processor", reduced_path) - ] = clauses + path.set(context, "eager_row_processor", clauses) return clauses, adapter, add_to_collection, allow_innerjoin def _create_eager_join(self, context, entity, @@ -1055,6 +1093,7 @@ class JoinedLoader(AbstractRelationshipLoader): context.query._should_nest_selectable entity_key = None + if entity not in context.eager_joins and \ not should_nest_selectable and \ context.from_clause: @@ -1096,6 +1135,7 @@ class JoinedLoader(AbstractRelationshipLoader): else: onclause = self.parent_property + assert clauses.aliased_class is not None context.eager_joins[entity_key] = eagerjoin = \ mapperutil.join( towrap, @@ -1134,12 +1174,12 @@ class JoinedLoader(AbstractRelationshipLoader): ) - def _create_eager_adapter(self, context, row, adapter, path, reduced_path): - if ("user_defined_eager_row_processor", reduced_path) in \ - context.attributes: - decorator = context.attributes[ - ("user_defined_eager_row_processor", - reduced_path)] + def _create_eager_adapter(self, context, row, adapter, path): + user_defined_adapter = path.get(context, + "user_defined_eager_row_processor", + False) + if user_defined_adapter is not False: + decorator = user_defined_adapter # user defined eagerloads are part of the "primary" # portion of the load. # the adapters applied to the Query should be honored. @@ -1147,11 +1187,10 @@ class JoinedLoader(AbstractRelationshipLoader): decorator = decorator.wrap(context.adapter) elif context.adapter: decorator = context.adapter - elif ("eager_row_processor", reduced_path) in context.attributes: - decorator = context.attributes[ - ("eager_row_processor", reduced_path)] else: - return False + decorator = path.get(context, "eager_row_processor") + if decorator is None: + return False try: self.mapper.identity_key_from_row(row, decorator) @@ -1161,28 +1200,26 @@ class JoinedLoader(AbstractRelationshipLoader): # processor, will cause a degrade to lazy return False - def create_row_processor(self, context, path, reduced_path, mapper, row, adapter): + def create_row_processor(self, context, path, mapper, row, adapter): if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " "population - eager loading cannot be applied." % self) - our_path = path + (self.key,) - our_reduced_path = reduced_path + (self.key,) + our_path = path[self.key] eager_adapter = self._create_eager_adapter( context, row, - adapter, our_path, - our_reduced_path) + adapter, our_path) if eager_adapter is not False: key = self.key + _instance = self.mapper._instance_processor( context, - our_path + (self.mapper,), - our_reduced_path + (self.mapper.base_mapper,), + our_path[self.mapper], eager_adapter) if not self.uselist: @@ -1193,8 +1230,7 @@ class JoinedLoader(AbstractRelationshipLoader): return self.parent_property.\ _get_strategy(LazyLoader).\ create_row_processor( - context, path, - reduced_path, + context, path, mapper, row, adapter) def _create_collection_loader(self, context, key, _instance): @@ -1279,19 +1315,18 @@ class EagerLazyOption(StrategizedOption): def get_strategy_class(self): return self.strategy_cls +_factory = { + False:JoinedLoader, + "joined":JoinedLoader, + None:NoLoader, + "noload":NoLoader, + "select":LazyLoader, + True:LazyLoader, + "subquery":SubqueryLoader, + "immediate":ImmediateLoader +} def factory(identifier): - if identifier is False or identifier == 'joined': - return JoinedLoader - elif identifier is None or identifier == 'noload': - return NoLoader - elif identifier is False or identifier == 'select': - return LazyLoader - elif identifier == 'subquery': - return SubqueryLoader - elif identifier == 'immediate': - return ImmediateLoader - else: - return LazyLoader + return _factory.get(identifier, LazyLoader) class EagerJoinOption(PropertyOption): @@ -1300,12 +1335,12 @@ class EagerJoinOption(PropertyOption): self.innerjoin = innerjoin self.chained = chained - def process_query_property(self, query, paths, mappers): + def process_query_property(self, query, paths): if self.chained: for path in paths: - query._attributes[("eager_join_type", path)] = self.innerjoin + path.set(query, "eager_join_type", self.innerjoin) else: - query._attributes[("eager_join_type", paths[-1])] = self.innerjoin + paths[-1].set(query, "eager_join_type", self.innerjoin) class LoadEagerFromAliasOption(PropertyOption): @@ -1313,36 +1348,41 @@ class LoadEagerFromAliasOption(PropertyOption): super(LoadEagerFromAliasOption, self).__init__(key) if alias is not None: if not isinstance(alias, basestring): - m, alias, is_aliased_class = mapperutil._entity_info(alias) + mapper, alias, is_aliased_class = \ + mapperutil._entity_info(alias) self.alias = alias self.chained = chained - def process_query_property(self, query, paths, mappers): + def process_query_property(self, query, paths): if self.chained: for path in paths[0:-1]: - (root_mapper, propname) = path[-2:] + (root_mapper, propname) = path.path[-2:] prop = root_mapper._props[propname] adapter = query._polymorphic_adapters.get(prop.mapper, None) - query._attributes.setdefault( - ("user_defined_eager_row_processor", - interfaces._reduce_path(path)), adapter) + path.setdefault(query, + "user_defined_eager_row_processor", + adapter) + root_mapper, propname = paths[-1].path[-2:] + prop = root_mapper._props[propname] if self.alias is not None: if isinstance(self.alias, basestring): - (root_mapper, propname) = paths[-1][-2:] - prop = root_mapper._props[propname] self.alias = prop.target.alias(self.alias) - query._attributes[ - ("user_defined_eager_row_processor", - interfaces._reduce_path(paths[-1])) - ] = sql_util.ColumnAdapter(self.alias) + paths[-1].set(query, "user_defined_eager_row_processor", + sql_util.ColumnAdapter(self.alias, + equivalents=prop.mapper._equivalent_columns) + ) else: - (root_mapper, propname) = paths[-1][-2:] - prop = root_mapper._props[propname] - adapter = query._polymorphic_adapters.get(prop.mapper, None) - query._attributes[ - ("user_defined_eager_row_processor", - interfaces._reduce_path(paths[-1]))] = adapter + if paths[-1].contains(query, "path_with_polymorphic"): + with_poly_info = paths[-1].get(query, "path_with_polymorphic") + adapter = mapperutil.ORMAdapter( + with_poly_info.entity, + equivalents=prop.mapper._equivalent_columns, + adapt_required=True) + else: + adapter = query._polymorphic_adapters.get(prop.mapper, None) + paths[-1].set(query, "user_defined_eager_row_processor", + adapter) def single_parent_validator(desc, prop): def _do_check(state, value, oldvalue, initiator): @@ -1363,6 +1403,8 @@ def single_parent_validator(desc, prop): def set_(state, value, oldvalue, initiator): return _do_check(state, value, oldvalue, initiator) - event.listen(desc, 'append', append, raw=True, retval=True, active_history=True) - event.listen(desc, 'set', set_, raw=True, retval=True, active_history=True) + event.listen(desc, 'append', append, raw=True, retval=True, + active_history=True) + event.listen(desc, 'set', set_, raw=True, retval=True, + active_history=True) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 51aaa3152..0978ab693 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -9,6 +9,7 @@ from sqlalchemy import sql, util, event, exc as sa_exc, inspection from sqlalchemy.sql import expression, util as sql_util, operators from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE,\ PropComparator, MapperProperty +from itertools import chain from sqlalchemy.orm import attributes, exc import operator import re @@ -233,6 +234,144 @@ class ORMAdapter(sql_util.ColumnAdapter): else: return None +class PathRegistry(object): + """Represent query load paths and registry functions. + + Basically represents structures like: + + (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>) + + These structures are generated by things like + query options (joinedload(), subqueryload(), etc.) and are + used to compose keys stored in the query._attributes dictionary + for various options. + + They are then re-composed at query compile/result row time as + the query is formed and as rows are fetched, where they again + serve to compose keys to look up options in the context.attributes + dictionary, which is copied from query._attributes. + + The path structure has a limited amount of caching, where each + "root" ultimately pulls from a fixed registry associated with + the first mapper, that also contains elements for each of its + property keys. However paths longer than two elements, which + are the exception rather than the rule, are generated on an + as-needed basis. + + """ + + def __eq__(self, other): + return other is not None and \ + self.path == other.path + + def set(self, reg, key, value): + reg._attributes[(key, self.reduced_path)] = value + + def setdefault(self, reg, key, value): + reg._attributes.setdefault((key, self.reduced_path), value) + + def get(self, reg, key, value=None): + key = (key, self.reduced_path) + if key in reg._attributes: + return reg._attributes[key] + else: + return value + + @property + def length(self): + return len(self.path) + + def contains_mapper(self, mapper): + return mapper.base_mapper in self.reduced_path + + def contains(self, reg, key): + return (key, self.reduced_path) in reg._attributes + + def serialize(self): + path = self.path + return zip( + [m.class_ for m in [path[i] for i in range(0, len(path), 2)]], + [path[i] for i in range(1, len(path), 2)] + [None] + ) + + @classmethod + def deserialize(cls, path): + if path is None: + return None + + p = tuple(chain(*[(class_mapper(mcls), key) for mcls, key in path])) + if p and p[-1] is None: + p = p[0:-1] + return cls.coerce(p) + + @classmethod + def per_mapper(cls, mapper): + return EntityRegistry( + cls.root, mapper + ) + + @classmethod + def coerce(cls, raw): + return util.reduce(lambda prev, next:prev[next], raw, cls.root) + + @classmethod + def token(cls, token): + return KeyRegistry(cls.root, token) + + def __add__(self, other): + return util.reduce( + lambda prev, next:prev[next], + other.path, self) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.path, ) + +class RootRegistry(PathRegistry): + """Root registry, defers to mappers so that + paths are maintained per-root-mapper. + + """ + path = () + reduced_path = () + + def __getitem__(self, mapper): + return mapper._sa_path_registry +PathRegistry.root = RootRegistry() + +class KeyRegistry(PathRegistry): + def __init__(self, parent, key): + self.key = key + self.parent = parent + self.path = parent.path + (key,) + self.reduced_path = parent.reduced_path + (key,) + + def __getitem__(self, entity): + return EntityRegistry( + self, entity + ) + +class EntityRegistry(PathRegistry, dict): + is_aliased_class = False + + def __init__(self, parent, entity): + self.key = reduced_key = entity + self.parent = parent + if hasattr(entity, 'base_mapper'): + reduced_key = entity.base_mapper + else: + self.is_aliased_class = True + + self.path = parent.path + (entity,) + self.reduced_path = parent.reduced_path + (reduced_key,) + + def __nonzero__(self): + return True + + def __missing__(self, key): + self[key] = item = KeyRegistry(self, key) + return item + + class AliasedClass(object): """Represents an "aliased" form of a mapped class for usage with Query. @@ -321,6 +460,10 @@ class AliasedClass(object): self._sa_label_name = name self.__name__ = 'AliasedClass_' + str(self.__target) + @util.memoized_property + def _sa_path_registry(self): + return PathRegistry.per_mapper(self) + def __getstate__(self): return { 'mapper':self.__mapper, @@ -408,7 +551,8 @@ def aliased(element, alias=None, name=None, adapt_on_names=False): name=name, adapt_on_names=adapt_on_names) def with_polymorphic(base, classes, selectable=False, - polymorphic_on=None, aliased=False): + polymorphic_on=None, aliased=False, + innerjoin=False): """Produce an :class:`.AliasedClass` construct which specifies columns for descendant mappers of the given base. @@ -422,23 +566,23 @@ def with_polymorphic(base, classes, selectable=False, criterion to be used against those tables. The resulting instances will also have those columns already loaded so that no "post fetch" of those columns will be required. - + See the examples at :ref:`with_polymorphic`. :param base: Base class to be aliased. - + :param cls_or_mappers: a single class or mapper, or list of class/mappers, which inherit from the base class. Alternatively, it may also be the string ``'*'``, in which case all descending mapped classes will be added to the FROM clause. - + :param aliased: when True, the selectable will be wrapped in an alias, that is ``(SELECT * FROM <fromclauses>) AS anon_1``. This can be important when using the with_polymorphic() to create the target of a JOIN on a backend that does not support parenthesized joins, such as SQLite and older versions of MySQL. - + :param selectable: a table or select() statement that will be used in place of the generated FROM clause. This argument is required if any of the desired classes use concrete table @@ -455,10 +599,12 @@ def with_polymorphic(base, classes, selectable=False, is useful for mappings that don't have polymorphic loading behavior by default. + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only """ - primary_mapper = class_mapper(base) + primary_mapper = _class_to_mapper(base) mappers, selectable = primary_mapper.\ - _with_polymorphic_args(classes, selectable) + _with_polymorphic_args(classes, selectable, innerjoin=innerjoin) if aliased: selectable = selectable.alias() return AliasedClass(base, @@ -478,11 +624,11 @@ def _orm_annotate(element, exclude=None): def _orm_deannotate(element): """Remove annotations that link a column to a particular mapping. - + Note this doesn't affect "remote" and "foreign" annotations passed by the :func:`.orm.foreign` and :func:`.orm.remote` annotators. - + """ return sql_util._deep_deannotate(element, @@ -644,13 +790,24 @@ def with_parent(instance, prop): value_is_parent=True) +_extended_entity_info_tuple = util.namedtuple("extended_entity_info", [ + "entity", + "mapper", + "selectable", + "is_aliased_class", + "with_polymorphic_mappers", + "with_polymorphic_discriminator" +]) def _extended_entity_info(entity, compile=True): if isinstance(entity, AliasedClass): - return entity._AliasedClass__mapper, \ - entity._AliasedClass__alias, \ - True, \ - entity._AliasedClass__with_polymorphic_mappers, \ - entity._AliasedClass__with_polymorphic_discriminator + return _extended_entity_info_tuple( + entity, + entity._AliasedClass__mapper, \ + entity._AliasedClass__alias, \ + True, \ + entity._AliasedClass__with_polymorphic_mappers, \ + entity._AliasedClass__with_polymorphic_discriminator + ) if isinstance(entity, mapperlib.Mapper): mapper = entity @@ -659,19 +816,22 @@ def _extended_entity_info(entity, compile=True): class_manager = attributes.manager_of_class(entity) if class_manager is None: - return None, entity, False, [], None + return _extended_entity_info_tuple(entity, None, entity, False, [], None) mapper = class_manager.mapper else: - return None, entity, False, [], None + return _extended_entity_info_tuple(entity, None, entity, False, [], None) if compile and mapperlib.module._new_mappers: mapperlib.configure_mappers() - return mapper, \ + return _extended_entity_info_tuple( + entity, + mapper, \ mapper._with_polymorphic_selectable, \ False, \ mapper._with_polymorphic_mappers, \ mapper.polymorphic_on + ) def _entity_info(entity, compile=True): """Return mapping information given a class, mapper, or AliasedClass. @@ -684,7 +844,7 @@ def _entity_info(entity, compile=True): unmapped selectables through. """ - return _extended_entity_info(entity, compile)[0:3] + return _extended_entity_info(entity, compile)[1:4] def _entity_descriptor(entity, key): """Return a class attribute given an entity and string name. @@ -738,7 +898,7 @@ def object_mapper(instance): Raises UnmappedInstanceError if no mapping is configured. This function is available via the inspection system as:: - + inspect(instance).mapper """ @@ -752,7 +912,7 @@ def object_state(instance): Raises UnmappedInstanceError if no mapping is configured. This function is available via the inspection system as:: - + inspect(instance) """ @@ -776,9 +936,9 @@ def class_mapper(class_, compile=True): object is passed. This function is available via the inspection system as:: - + inspect(some_mapped_class) - + """ try: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index f836d7eaf..bc0497bea 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -3500,9 +3500,14 @@ class _Exists(_UnaryExpression): def select(self, whereclause=None, **params): return select([self], whereclause, **params) - def correlate(self, fromclause): + def correlate(self, *fromclause): e = self._clone() - e.element = self.element.correlate(fromclause).self_group() + e.element = self.element.correlate(*fromclause).self_group() + return e + + def correlate_except(self, *fromclause): + e = self._clone() + e.element = self.element.correlate_except(*fromclause).self_group() return e def select_from(self, clause): @@ -4708,7 +4713,8 @@ class Select(_SelectBase): _hints = util.immutabledict() _distinct = False _from_cloned = None - + _correlate = () + _correlate_except = () _memoized_property = _SelectBase._memoized_property def __init__(self, @@ -4750,7 +4756,6 @@ class Select(_SelectBase): for e in util.to_list(distinct) ] - self._correlate = set() if from_obj is not None: self._from_obj = util.OrderedSet( _literal_as_text(f) @@ -4837,10 +4842,13 @@ class Select(_SelectBase): # using a list to maintain ordering froms = [f for f in froms if f not in toremove] - if len(froms) > 1 or self._correlate: + if len(froms) > 1 or self._correlate or self._correlate_except: if self._correlate: froms = [f for f in froms if f not in _cloned_intersection(froms, self._correlate)] + if self._correlate_except: + froms = [f for f in froms if f in _cloned_intersection(froms, + self._correlate_except)] if self._should_correlate and existing_froms: froms = [f for f in froms if f not in _cloned_intersection(froms, existing_froms)] @@ -5198,16 +5206,24 @@ class Select(_SelectBase): """ self._should_correlate = False if fromclauses and fromclauses[0] is None: - self._correlate = set() + self._correlate = () + else: + self._correlate = set(self._correlate).union(fromclauses) + + @_generative + def correlate_except(self, *fromclauses): + self._should_correlate = False + if fromclauses and fromclauses[0] is None: + self._correlate_except = () else: - self._correlate = self._correlate.union(fromclauses) + self._correlate_except = set(self._correlate_except).union(fromclauses) def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" self._should_correlate = False - self._correlate = self._correlate.union([fromclause]) + self._correlate = set(self._correlate).union([fromclause]) def append_column(self, column): """append the given column expression to the columns clause of this diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 76c3c829d..3cfe55f9c 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -7,7 +7,7 @@ from compat import callable, cmp, reduce, defaultdict, py25_dict, \ threading, py3k_warning, jython, pypy, win32, set_types, buffer, pickle, \ update_wrapper, partial, md5_hex, decode_slice, dottedgetter,\ - parse_qsl, any, contextmanager + parse_qsl, any, contextmanager, namedtuple from _collections import NamedTuple, ImmutableContainer, immutabledict, \ Properties, OrderedProperties, ImmutableProperties, OrderedDict, \ diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 99b92b1e3..c5339d013 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -112,6 +112,18 @@ else: reduce = reduce try: + from collections import namedtuple +except ImportError: + def namedtuple(typename, fieldnames): + def __new__(cls, *values): + tup = tuple.__new__(tuptype, values) + for i, fname in enumerate(fieldnames): + setattr(tup, fname, tup[i]) + return tup + tuptype = type(typename, (tuple, ), {'__new__':__new__}) + return tuptype + +try: from collections import defaultdict except ImportError: class defaultdict(dict): |
