diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-06-02 03:07:12 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-06-02 03:07:12 +0000 |
| commit | e525aee01556e59ff9fc02dd68fd6a38532fe45a (patch) | |
| tree | 3bb7b616ed238af5e3596c3250565794f19db0d7 /lib | |
| parent | e3e15357202d4c7918fe6a3b8d2d002e4ede3de2 (diff) | |
| download | sqlalchemy-e525aee01556e59ff9fc02dd68fd6a38532fe45a.tar.gz | |
- removed query.min()/max()/sum()/avg(). these should be called using column arguments or values in conjunction with func.
- fixed [ticket:1008], count() works with single table inheritance
- changed the relationship of InstrumentedAttribute to class such that each subclass in an inheritance hierarchy gets a unique InstrumentedAttribute per column-oriented attribute, including for the same underlying ColumnProperty. This allows expressions from subclasses to be annotated accurately so that Query can get a hold of the exact entities to be queried when using column-based expressions. This repairs various polymorphic scenarios with both single and joined table inheritance.
- still to be determined is what does something like query(Person.name, Engineer.engineer_info) do; currently it's problematic. Even trickier is query(Person.name, Engineer.engineer_info, Manager.manager_name)
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 35 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 99 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 74 |
4 files changed, 144 insertions, 74 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 6c9fe7753..7b120e884 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -439,9 +439,10 @@ class PropComparator(expression.ColumnOperators): return a.has(b, **kwargs) has_op = staticmethod(has_op) - def __init__(self, prop): + def __init__(self, prop, mapper): self.prop = self.property = prop - + self.mapper = mapper + def of_type_op(a, class_): return a.of_type(class_) of_type_op = staticmethod(of_type_op) @@ -753,11 +754,12 @@ class LoaderStrategy(object): def __init__(self, parent): self.parent_property = parent self.is_class_level = False - - def init(self): self.parent = self.parent_property.parent self.key = self.parent_property.key + def init(self): + raise NotImplementedError("LoaderStrategy") + def init_class_attribute(self): pass diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c7300f216..792824fe3 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -38,7 +38,7 @@ class ColumnProperty(StrategizedProperty): self.columns = [expression._labeled(c) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) - self.comparator = ColumnProperty.ColumnComparator(self) + self.comparator_factory = ColumnProperty.ColumnComparator util.set_creation_order(self) if self.deferred: self.strategy_class = strategies.DeferredColumnLoader @@ -80,7 +80,7 @@ class ColumnProperty(StrategizedProperty): class ColumnComparator(PropComparator): def __clause_element__(self): - return self.prop.columns[0]._annotate({"parententity": self.prop.parent}) + return self.prop.columns[0]._annotate({"parententity": self.mapper}) __clause_element__ = util.cache_decorator(__clause_element__) def operate(self, op, *other, **kwargs): @@ -101,7 +101,7 @@ class CompositeProperty(ColumnProperty): def __init__(self, class_, *columns, **kwargs): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ - self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self) + self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator) self.strategy_class = strategies.CompositeColumnLoader def do_init(self): @@ -170,8 +170,7 @@ class SynonymProperty(MapperProperty): def do_init(self): class_ = self.parent.class_ - def comparator(): - return self.parent._get_property(self.key, resolve_synonyms=True).comparator + self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__)) if self.descriptor is None: class SynonymProp(object): @@ -184,7 +183,14 @@ class SynonymProperty(MapperProperty): return s return getattr(obj, self.name) self.descriptor = SynonymProp() - sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent) + + def comparator_callable(prop, mapper): + def comparator(): + prop = self.parent._get_property(self.key, resolve_synonyms=True) + return prop.comparator_factory(prop, mapper) + return comparator + + strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor) def merge(self, session, source, dest, _recursive): pass @@ -195,18 +201,13 @@ class ComparableProperty(MapperProperty): def __init__(self, comparator_factory, descriptor=None): self.descriptor = descriptor - self.comparator = comparator_factory(self) + self.comparator_factory = comparator_factory util.set_creation_order(self) def do_init(self): """Set up a proxy to the unmanaged descriptor.""" - class_ = self.parent.class_ - # refactor me - sessionlib.register_attribute(class_, self.key, uselist=False, - proxy_property=self.descriptor, - useobject=False, - comparator=self.comparator) + strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor) def setup(self, context, entity, path, adapter, **kwargs): pass @@ -252,10 +253,11 @@ class PropertyLoader(StrategizedProperty): self.passive_updates = passive_updates self.remote_side = remote_side self.enable_typechecks = enable_typechecks - self.comparator = PropertyLoader.Comparator(self) + self.comparator = PropertyLoader.Comparator(self, None) self.join_depth = join_depth self.local_remote_pairs = _local_remote_pairs self.__join_cache = {} + self.comparator_factory = PropertyLoader.Comparator util.set_creation_order(self) if strategy_class: @@ -295,8 +297,9 @@ class PropertyLoader(StrategizedProperty): self._is_backref = _is_backref class Comparator(PropComparator): - def __init__(self, prop, of_type=None): + def __init__(self, prop, mapper, of_type=None): self.prop = self.property = prop + self.mapper = mapper if of_type: self._of_type = _class_to_mapper(of_type) @@ -314,7 +317,7 @@ class PropertyLoader(StrategizedProperty): return op(self, *other, **kwargs) def of_type(self, cls): - return PropertyLoader.Comparator(self.prop, cls) + return PropertyLoader.Comparator(self.prop, self.mapper, cls) def __eq__(self, other): if other is None: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 555c376e5..43f206f38 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -97,7 +97,14 @@ class Query(object): self.__setup_aliasizers(self._entities) def __setup_aliasizers(self, entities): - d = {} + if hasattr(self, '_mapper_adapter_map'): + # usually safe to share a single map, but copying to prevent + # subtle leaks if end-user is reusing base query with arbitrary + # number of aliased() objects + self._mapper_adapter_map = d = self._mapper_adapter_map.copy() + else: + self._mapper_adapter_map = d = {} + for ent in entities: for entity in ent.entities: if entity not in d: @@ -114,7 +121,7 @@ class Query(object): d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) ent.setup_entity(entity, *d[entity]) - + def __mapper_loads_polymorphically_with(self, mapper, adapter): for m2 in mapper._with_polymorphic_mappers: for m in m2.iterate_to_root(): @@ -650,26 +657,6 @@ class Query(object): return self.filter(sql.and_(*clauses)) - def min(self, col): - """Execute the SQL ``min()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.min) - - def max(self, col): - """Execute the SQL ``max()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.max) - - def sum(self, col): - """Execute the SQL ``sum()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.sum) - - def avg(self, col): - """Execute the SQL ``avg()`` function against the given column.""" - - return self._col_aggregate(col, sql.func.avg) - def order_by(self, *criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" @@ -1213,18 +1200,17 @@ class Query(object): _should_nest_selectable = property(_should_nest_selectable) def count(self): - """Apply this query's criterion to a SELECT COUNT statement. - - this is the purely generative version which will become - the public method in version 0.5. - - """ - return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key)) + """Apply this query's criterion to a SELECT COUNT statement.""" + + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key)) def _col_aggregate(self, col, func, nested_cols=None): - whereclause = self._criterion - context = QueryContext(self) + + self._adjust_for_single_inheritance(context) + + whereclause = context.whereclause + from_obj = self.__mapper_zero_from_obj() if self._should_nest_selectable: @@ -1371,7 +1357,9 @@ class Query(object): froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used else: froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM - + + self._adjust_for_single_inheritance(context) + if eager_joins and self._should_nest_selectable: # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select, # then append eager joins onto that @@ -1382,7 +1370,15 @@ class Query(object): context.order_by = None order_by_col_expr = [] - inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args) + inner = sql.select( + context.primary_columns + order_by_col_expr, + context.whereclause, + from_obj=froms, + use_labels=labels, + correlate=False, + order_by=context.order_by, + **self._select_args + ) if self._correlate: inner = inner.correlate(*self._correlate) @@ -1418,7 +1414,17 @@ class Query(object): froms += context.eager_joins.values() - statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args) + statement = sql.select( + context.primary_columns + context.secondary_columns, + context.whereclause, + from_obj=froms, + use_labels=labels, + for_update=for_update, + correlate=False, + order_by=context.order_by, + **self._select_args + ) + if self._correlate: statement = statement.correlate(*self._correlate) @@ -1429,6 +1435,22 @@ class Query(object): return context + def _adjust_for_single_inheritance(self, context): + """Apply single-table-inheritance filtering. + + For all distinct single-table-inheritance mappers represented in the columns + clause of this query, add criterion to the WHERE clause of the given QueryContext + such that only the appropriate subtypes are selected from the total results. + + """ + for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems(): + if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None: + crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()]) + if adapter: + crit = adapter.traverse(crit) + crit = self._adapt_clause(crit, False, False) + context.whereclause = sql.and_(context.whereclause, crit) + def __log_debug(self, msg): self.logger.debug(msg) @@ -1463,7 +1485,7 @@ class _MapperEntity(_QueryEntity): self.entities = [entity] self.entity_zero = entity self.entity_name = entity_name - + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): self.mapper = mapper self.extension = self.mapper.extension @@ -1554,15 +1576,10 @@ class _MapperEntity(_QueryEntity): return main, entname def setup_context(self, query, context): - # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so - # that we only load the appropriate types - if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: - context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) + adapter = self._get_entity_clauses(query, context) context.froms.append(self.selectable) - adapter = self._get_entity_clauses(query, context) - if context.order_by is False and self.mapper.order_by: context.order_by = self.mapper.order_by diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d40937a18..829210205 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -17,11 +17,31 @@ from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -class ColumnLoader(LoaderStrategy): - """Default column loader.""" +class DefaultColumnLoader(LoaderStrategy): + def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None): + self.logger.info("%s register managed attribute" % self) + + for mapper in self.parent.polymorphic_iterator(): + if mapper is self.parent or not mapper.concrete: + sessionlib.register_attribute( + mapper.class_, + self.key, + uselist=False, + useobject=False, + copy_function=copy_function, + compare_function=compare_function, + mutable_scalars=mutable_scalars, + comparator=comparator_factory(self.parent_property, mapper), + parententity=mapper, + callable_=callable_, + proxy_property=proxy_property + ) + +DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader) + +class ColumnLoader(DefaultColumnLoader): def init(self): - super(ColumnLoader, self).init() self.columns = self.parent_property.columns self._should_log_debug = log.is_debug_enabled(self.logger) self.is_composite = hasattr(self.parent_property, 'composite_class') @@ -34,9 +54,14 @@ class ColumnLoader(LoaderStrategy): def init_class_attribute(self): self.is_class_level = True - self.logger.info("%s register managed attribute" % self) coltype = self.columns[0].type - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) + + self._register_attribute( + coltype.compare_values, + coltype.copy_value, + self.columns[0].type.is_mutable(), + self.parent_property.comparator_factory + ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): key, col = self.key, self.columns[0] @@ -78,7 +103,13 @@ class CompositeColumnLoader(ColumnLoader): return False else: return True - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent) + + self._register_attribute( + compare, + copy, + True, + self.parent_property.comparator_factory + ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class @@ -106,7 +137,7 @@ class CompositeColumnLoader(ColumnLoader): CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader) -class DeferredColumnLoader(LoaderStrategy): +class DeferredColumnLoader(DefaultColumnLoader): """Deferred column loader, a per-column or per-column-group lazy loader.""" def create_row_processor(self, selectcontext, path, mapper, row, adapter): @@ -130,7 +161,6 @@ class DeferredColumnLoader(LoaderStrategy): return (new_execute, None) def init(self): - super(DeferredColumnLoader, self).init() if hasattr(self.parent_property, 'composite_class'): raise NotImplementedError("Deferred loading for composite types not implemented yet") self.columns = self.parent_property.columns @@ -139,8 +169,13 @@ class DeferredColumnLoader(LoaderStrategy): def init_class_attribute(self): self.is_class_level = True - self.logger.info("%s register managed attribute" % self) - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) + self._register_attribute( + self.columns[0].type.compare_values, + self.columns[0].type.copy_value, + self.columns[0].type.is_mutable(), + self.parent_property.comparator_factory, + callable_=self.class_level_loader, + ) def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs): if \ @@ -238,7 +273,6 @@ class UndeferGroupOption(MapperOption): class AbstractRelationLoader(LoaderStrategy): def init(self): - super(AbstractRelationLoader, self).init() for attr in ['mapper', 'target', 'table', 'uselist']: setattr(self, attr, getattr(self.parent_property, attr)) self._should_log_debug = log.is_debug_enabled(self.logger) @@ -249,7 +283,7 @@ class AbstractRelationLoader(LoaderStrategy): else: state.initialize(self.key) - def _register_attribute(self, class_, callable_=None, **kwargs): + def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs): self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar"))) if self.parent_property.backref: @@ -257,7 +291,21 @@ class AbstractRelationLoader(LoaderStrategy): else: attribute_ext = None - sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs) + sessionlib.register_attribute( + class_, + self.key, + uselist=self.uselist, + useobject=True, + extension=attribute_ext, + cascade=self.parent_property.cascade, + trackparent=True, + typecallable=self.parent_property.collection_class, + callable_=callable_, + comparator=self.parent_property.comparator, + parententity=self.parent, + impl_class=impl_class, + **kwargs + ) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): |
