summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-06-02 03:07:12 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-06-02 03:07:12 +0000
commite525aee01556e59ff9fc02dd68fd6a38532fe45a (patch)
tree3bb7b616ed238af5e3596c3250565794f19db0d7 /lib
parente3e15357202d4c7918fe6a3b8d2d002e4ede3de2 (diff)
downloadsqlalchemy-e525aee01556e59ff9fc02dd68fd6a38532fe45a.tar.gz
- removed query.min()/max()/sum()/avg(). these should be called using column arguments or values in conjunction with func.
- fixed [ticket:1008], count() works with single table inheritance - changed the relationship of InstrumentedAttribute to class such that each subclass in an inheritance hierarchy gets a unique InstrumentedAttribute per column-oriented attribute, including for the same underlying ColumnProperty. This allows expressions from subclasses to be annotated accurately so that Query can get a hold of the exact entities to be queried when using column-based expressions. This repairs various polymorphic scenarios with both single and joined table inheritance. - still to be determined is what does something like query(Person.name, Engineer.engineer_info) do; currently it's problematic. Even trickier is query(Person.name, Engineer.engineer_info, Manager.manager_name)
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py10
-rw-r--r--lib/sqlalchemy/orm/properties.py35
-rw-r--r--lib/sqlalchemy/orm/query.py99
-rw-r--r--lib/sqlalchemy/orm/strategies.py74
4 files changed, 144 insertions, 74 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 6c9fe7753..7b120e884 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -439,9 +439,10 @@ class PropComparator(expression.ColumnOperators):
return a.has(b, **kwargs)
has_op = staticmethod(has_op)
- def __init__(self, prop):
+ def __init__(self, prop, mapper):
self.prop = self.property = prop
-
+ self.mapper = mapper
+
def of_type_op(a, class_):
return a.of_type(class_)
of_type_op = staticmethod(of_type_op)
@@ -753,11 +754,12 @@ class LoaderStrategy(object):
def __init__(self, parent):
self.parent_property = parent
self.is_class_level = False
-
- def init(self):
self.parent = self.parent_property.parent
self.key = self.parent_property.key
+ def init(self):
+ raise NotImplementedError("LoaderStrategy")
+
def init_class_attribute(self):
pass
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index c7300f216..792824fe3 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -38,7 +38,7 @@ class ColumnProperty(StrategizedProperty):
self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
- self.comparator = ColumnProperty.ColumnComparator(self)
+ self.comparator_factory = ColumnProperty.ColumnComparator
util.set_creation_order(self)
if self.deferred:
self.strategy_class = strategies.DeferredColumnLoader
@@ -80,7 +80,7 @@ class ColumnProperty(StrategizedProperty):
class ColumnComparator(PropComparator):
def __clause_element__(self):
- return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
+ return self.prop.columns[0]._annotate({"parententity": self.mapper})
__clause_element__ = util.cache_decorator(__clause_element__)
def operate(self, op, *other, **kwargs):
@@ -101,7 +101,7 @@ class CompositeProperty(ColumnProperty):
def __init__(self, class_, *columns, **kwargs):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
- self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+ self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator)
self.strategy_class = strategies.CompositeColumnLoader
def do_init(self):
@@ -170,8 +170,7 @@ class SynonymProperty(MapperProperty):
def do_init(self):
class_ = self.parent.class_
- def comparator():
- return self.parent._get_property(self.key, resolve_synonyms=True).comparator
+
self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__))
if self.descriptor is None:
class SynonymProp(object):
@@ -184,7 +183,14 @@ class SynonymProperty(MapperProperty):
return s
return getattr(obj, self.name)
self.descriptor = SynonymProp()
- sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
+
+ def comparator_callable(prop, mapper):
+ def comparator():
+ prop = self.parent._get_property(self.key, resolve_synonyms=True)
+ return prop.comparator_factory(prop, mapper)
+ return comparator
+
+ strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor)
def merge(self, session, source, dest, _recursive):
pass
@@ -195,18 +201,13 @@ class ComparableProperty(MapperProperty):
def __init__(self, comparator_factory, descriptor=None):
self.descriptor = descriptor
- self.comparator = comparator_factory(self)
+ self.comparator_factory = comparator_factory
util.set_creation_order(self)
def do_init(self):
"""Set up a proxy to the unmanaged descriptor."""
- class_ = self.parent.class_
- # refactor me
- sessionlib.register_attribute(class_, self.key, uselist=False,
- proxy_property=self.descriptor,
- useobject=False,
- comparator=self.comparator)
+ strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor)
def setup(self, context, entity, path, adapter, **kwargs):
pass
@@ -252,10 +253,11 @@ class PropertyLoader(StrategizedProperty):
self.passive_updates = passive_updates
self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
- self.comparator = PropertyLoader.Comparator(self)
+ self.comparator = PropertyLoader.Comparator(self, None)
self.join_depth = join_depth
self.local_remote_pairs = _local_remote_pairs
self.__join_cache = {}
+ self.comparator_factory = PropertyLoader.Comparator
util.set_creation_order(self)
if strategy_class:
@@ -295,8 +297,9 @@ class PropertyLoader(StrategizedProperty):
self._is_backref = _is_backref
class Comparator(PropComparator):
- def __init__(self, prop, of_type=None):
+ def __init__(self, prop, mapper, of_type=None):
self.prop = self.property = prop
+ self.mapper = mapper
if of_type:
self._of_type = _class_to_mapper(of_type)
@@ -314,7 +317,7 @@ class PropertyLoader(StrategizedProperty):
return op(self, *other, **kwargs)
def of_type(self, cls):
- return PropertyLoader.Comparator(self.prop, cls)
+ return PropertyLoader.Comparator(self.prop, self.mapper, cls)
def __eq__(self, other):
if other is None:
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 555c376e5..43f206f38 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -97,7 +97,14 @@ class Query(object):
self.__setup_aliasizers(self._entities)
def __setup_aliasizers(self, entities):
- d = {}
+ if hasattr(self, '_mapper_adapter_map'):
+ # usually safe to share a single map, but copying to prevent
+ # subtle leaks if end-user is reusing base query with arbitrary
+ # number of aliased() objects
+ self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
+ else:
+ self._mapper_adapter_map = d = {}
+
for ent in entities:
for entity in ent.entities:
if entity not in d:
@@ -114,7 +121,7 @@ class Query(object):
d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
ent.setup_entity(entity, *d[entity])
-
+
def __mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers:
for m in m2.iterate_to_root():
@@ -650,26 +657,6 @@ class Query(object):
return self.filter(sql.and_(*clauses))
- def min(self, col):
- """Execute the SQL ``min()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.min)
-
- def max(self, col):
- """Execute the SQL ``max()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.max)
-
- def sum(self, col):
- """Execute the SQL ``sum()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.sum)
-
- def avg(self, col):
- """Execute the SQL ``avg()`` function against the given column."""
-
- return self._col_aggregate(col, sql.func.avg)
-
def order_by(self, *criterion):
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
@@ -1213,18 +1200,17 @@ class Query(object):
_should_nest_selectable = property(_should_nest_selectable)
def count(self):
- """Apply this query's criterion to a SELECT COUNT statement.
-
- this is the purely generative version which will become
- the public method in version 0.5.
-
- """
- return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
+ """Apply this query's criterion to a SELECT COUNT statement."""
+
+ return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key))
def _col_aggregate(self, col, func, nested_cols=None):
- whereclause = self._criterion
-
context = QueryContext(self)
+
+ self._adjust_for_single_inheritance(context)
+
+ whereclause = context.whereclause
+
from_obj = self.__mapper_zero_from_obj()
if self._should_nest_selectable:
@@ -1371,7 +1357,9 @@ class Query(object):
froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used
else:
froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
-
+
+ self._adjust_for_single_inheritance(context)
+
if eager_joins and self._should_nest_selectable:
# for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
# then append eager joins onto that
@@ -1382,7 +1370,15 @@ class Query(object):
context.order_by = None
order_by_col_expr = []
- inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
+ inner = sql.select(
+ context.primary_columns + order_by_col_expr,
+ context.whereclause,
+ from_obj=froms,
+ use_labels=labels,
+ correlate=False,
+ order_by=context.order_by,
+ **self._select_args
+ )
if self._correlate:
inner = inner.correlate(*self._correlate)
@@ -1418,7 +1414,17 @@ class Query(object):
froms += context.eager_joins.values()
- statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
+ statement = sql.select(
+ context.primary_columns + context.secondary_columns,
+ context.whereclause,
+ from_obj=froms,
+ use_labels=labels,
+ for_update=for_update,
+ correlate=False,
+ order_by=context.order_by,
+ **self._select_args
+ )
+
if self._correlate:
statement = statement.correlate(*self._correlate)
@@ -1429,6 +1435,22 @@ class Query(object):
return context
+ def _adjust_for_single_inheritance(self, context):
+ """Apply single-table-inheritance filtering.
+
+ For all distinct single-table-inheritance mappers represented in the columns
+ clause of this query, add criterion to the WHERE clause of the given QueryContext
+ such that only the appropriate subtypes are selected from the total results.
+
+ """
+ for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
+ if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
+ crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()])
+ if adapter:
+ crit = adapter.traverse(crit)
+ crit = self._adapt_clause(crit, False, False)
+ context.whereclause = sql.and_(context.whereclause, crit)
+
def __log_debug(self, msg):
self.logger.debug(msg)
@@ -1463,7 +1485,7 @@ class _MapperEntity(_QueryEntity):
self.entities = [entity]
self.entity_zero = entity
self.entity_name = entity_name
-
+
def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
self.mapper = mapper
self.extension = self.mapper.extension
@@ -1554,15 +1576,10 @@ class _MapperEntity(_QueryEntity):
return main, entname
def setup_context(self, query, context):
- # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
- # that we only load the appropriate types
- if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
- context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
+ adapter = self._get_entity_clauses(query, context)
context.froms.append(self.selectable)
- adapter = self._get_entity_clauses(query, context)
-
if context.order_by is False and self.mapper.order_by:
context.order_by = self.mapper.order_by
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index d40937a18..829210205 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -17,11 +17,31 @@ from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-class ColumnLoader(LoaderStrategy):
- """Default column loader."""
+class DefaultColumnLoader(LoaderStrategy):
+ def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None):
+ self.logger.info("%s register managed attribute" % self)
+
+ for mapper in self.parent.polymorphic_iterator():
+ if mapper is self.parent or not mapper.concrete:
+ sessionlib.register_attribute(
+ mapper.class_,
+ self.key,
+ uselist=False,
+ useobject=False,
+ copy_function=copy_function,
+ compare_function=compare_function,
+ mutable_scalars=mutable_scalars,
+ comparator=comparator_factory(self.parent_property, mapper),
+ parententity=mapper,
+ callable_=callable_,
+ proxy_property=proxy_property
+ )
+
+DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader)
+
+class ColumnLoader(DefaultColumnLoader):
def init(self):
- super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
self._should_log_debug = log.is_debug_enabled(self.logger)
self.is_composite = hasattr(self.parent_property, 'composite_class')
@@ -34,9 +54,14 @@ class ColumnLoader(LoaderStrategy):
def init_class_attribute(self):
self.is_class_level = True
- self.logger.info("%s register managed attribute" % self)
coltype = self.columns[0].type
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
+
+ self._register_attribute(
+ coltype.compare_values,
+ coltype.copy_value,
+ self.columns[0].type.is_mutable(),
+ self.parent_property.comparator_factory
+ )
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, col = self.key, self.columns[0]
@@ -78,7 +103,13 @@ class CompositeColumnLoader(ColumnLoader):
return False
else:
return True
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent)
+
+ self._register_attribute(
+ compare,
+ copy,
+ True,
+ self.parent_property.comparator_factory
+ )
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
@@ -106,7 +137,7 @@ class CompositeColumnLoader(ColumnLoader):
CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
-class DeferredColumnLoader(LoaderStrategy):
+class DeferredColumnLoader(DefaultColumnLoader):
"""Deferred column loader, a per-column or per-column-group lazy loader."""
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
@@ -130,7 +161,6 @@ class DeferredColumnLoader(LoaderStrategy):
return (new_execute, None)
def init(self):
- super(DeferredColumnLoader, self).init()
if hasattr(self.parent_property, 'composite_class'):
raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
@@ -139,8 +169,13 @@ class DeferredColumnLoader(LoaderStrategy):
def init_class_attribute(self):
self.is_class_level = True
- self.logger.info("%s register managed attribute" % self)
- sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
+ self._register_attribute(
+ self.columns[0].type.compare_values,
+ self.columns[0].type.copy_value,
+ self.columns[0].type.is_mutable(),
+ self.parent_property.comparator_factory,
+ callable_=self.class_level_loader,
+ )
def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
if \
@@ -238,7 +273,6 @@ class UndeferGroupOption(MapperOption):
class AbstractRelationLoader(LoaderStrategy):
def init(self):
- super(AbstractRelationLoader, self).init()
for attr in ['mapper', 'target', 'table', 'uselist']:
setattr(self, attr, getattr(self.parent_property, attr))
self._should_log_debug = log.is_debug_enabled(self.logger)
@@ -249,7 +283,7 @@ class AbstractRelationLoader(LoaderStrategy):
else:
state.initialize(self.key)
- def _register_attribute(self, class_, callable_=None, **kwargs):
+ def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
if self.parent_property.backref:
@@ -257,7 +291,21 @@ class AbstractRelationLoader(LoaderStrategy):
else:
attribute_ext = None
- sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs)
+ sessionlib.register_attribute(
+ class_,
+ self.key,
+ uselist=self.uselist,
+ useobject=True,
+ extension=attribute_ext,
+ cascade=self.parent_property.cascade,
+ trackparent=True,
+ typecallable=self.parent_property.collection_class,
+ callable_=callable_,
+ comparator=self.parent_property.comparator,
+ parententity=self.parent,
+ impl_class=impl_class,
+ **kwargs
+ )
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):