summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-21 06:57:20 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-21 06:57:20 +0000
commitc114f096bd0bd786916cbc42eaa91e0e1158ccf4 (patch)
tree2313856ffc6bfe8a3b61017a58efe98b4be322e7 /lib/sqlalchemy
parent2db36bf59c447d4d113cba0ae12f1b739c2ae923 (diff)
downloadsqlalchemy-c114f096bd0bd786916cbc42eaa91e0e1158ccf4.tar.gz
- reworked all lazy/deferred/expired callables to be
serializable class instances, added pickling tests - cleaned up "deferred" polymorphic system so that the mapper handles it entirely - columns which are missing from a Query's select statement now get automatically deferred during load.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/attributes.py62
-rw-r--r--lib/sqlalchemy/orm/interfaces.py23
-rw-r--r--lib/sqlalchemy/orm/mapper.py124
-rw-r--r--lib/sqlalchemy/orm/session.py8
-rw-r--r--lib/sqlalchemy/orm/strategies.py316
-rw-r--r--lib/sqlalchemy/orm/util.py6
-rw-r--r--lib/sqlalchemy/sql/expression.py5
7 files changed, 315 insertions, 229 deletions
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 089522673..135269906 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -255,12 +255,13 @@ class AttributeImpl(object):
class ScalarAttributeImpl(AttributeImpl):
"""represents a scalar value-holding InstrumentedAttribute."""
- accepts_global_callable = True
+ accepts_scalar_loader = True
def delete(self, state):
if self.key not in state.committed_state:
state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
state.modified=True
@@ -327,7 +328,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
Adds events to delete/set operations.
"""
- accepts_global_callable = False
+ accepts_scalar_loader = False
def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(ScalarObjectAttributeImpl, self).__init__(class_, key,
@@ -338,6 +339,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
def delete(self, state):
old = self.get(state)
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
self.fire_remove_event(state, old, self)
@@ -404,7 +406,7 @@ class CollectionAttributeImpl(AttributeImpl):
CollectionAdapter, a "view" onto that object that presents consistent
bag semantics to the orm layer independent of the user data implementation.
"""
- accepts_global_callable = False
+ accepts_scalar_loader = False
def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
super(CollectionAttributeImpl, self).__init__(class_,
@@ -479,6 +481,7 @@ class CollectionAttributeImpl(AttributeImpl):
collection = self.get_collection(state)
collection.clear_with_event()
+ # TODO: catch key errors, convert to attributeerror?
del state.dict[self.key]
def initialize(self, state):
@@ -648,7 +651,7 @@ class ClassState(object):
self.mappers = {}
self.attrs = {}
self.has_mutable_scalars = False
-
+
class InstanceState(object):
"""tracks state information at the instance level."""
@@ -658,7 +661,6 @@ class InstanceState(object):
self.dict = obj.__dict__
self.committed_state = {}
self.modified = False
- self.trigger = None
self.callables = {}
self.parents = {}
self.pending = {}
@@ -735,7 +737,7 @@ class InstanceState(object):
return None
def __getstate__(self):
- return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
+ return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables}
def __setstate__(self, state):
self.committed_state = state['committed_state']
@@ -745,43 +747,62 @@ class InstanceState(object):
self.obj = weakref.ref(state['instance'])
self.class_ = self.obj().__class__
self.dict = self.obj().__dict__
- self.callables = {}
- self.trigger = None
-
+ self.callables = state['callables']
+ self.runid = None
+ self.appenders = {}
+ if state['expired_attributes'] is not None:
+ self.expire_attributes(state['expired_attributes'])
+
def initialize(self, key):
getattr(self.class_, key).impl.initialize(self)
def set_callable(self, key, callable_):
self.dict.pop(key, None)
self.callables[key] = callable_
-
- def __fire_trigger(self):
+
+ def __call__(self):
+ """__call__ allows the InstanceState to act as a deferred
+ callable for loading expired attributes, which is also
+ serializable.
+ """
instance = self.obj()
- self.trigger(instance, [k for k in self.expired_attributes if k not in self.dict])
+ self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k not in self.committed_state])
for k in self.expired_attributes:
self.callables.pop(k, None)
self.expired_attributes.clear()
return ATTR_WAS_SET
+ def unmodified(self):
+ """a set of keys which have no uncommitted changes"""
+
+ return util.Set([
+ attr.impl.key for attr in _managed_attributes(self.class_) if
+ attr.impl.key not in self.committed_state
+ and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
+ ])
+ unmodified = property(unmodified)
+
def expire_attributes(self, attribute_names):
if not hasattr(self, 'expired_attributes'):
self.expired_attributes = util.Set()
+
if attribute_names is None:
for attr in _managed_attributes(self.class_):
self.dict.pop(attr.impl.key, None)
- self.callables[attr.impl.key] = self.__fire_trigger
- self.expired_attributes.add(attr.impl.key)
+
+ if attr.impl.accepts_scalar_loader:
+ self.callables[attr.impl.key] = self
+ self.expired_attributes.add(attr.impl.key)
+
self.committed_state = {}
else:
for key in attribute_names:
self.dict.pop(key, None)
self.committed_state.pop(key, None)
- if not getattr(self.class_, key).impl.accepts_global_callable:
- continue
-
- self.callables[key] = self.__fire_trigger
- self.expired_attributes.add(key)
+ if getattr(self.class_, key).impl.accepts_scalar_loader:
+ self.callables[key] = self
+ self.expired_attributes.add(key)
def reset(self, key):
"""remove the given attribute and any callables associated with it."""
@@ -1081,7 +1102,7 @@ def _init_class_state(class_):
if not '_class_state' in class_.__dict__:
class_._class_state = ClassState()
-def register_class(class_, extra_init=None, on_exception=None):
+def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
# do a sweep first, this also helps some attribute extensions
# (like associationproxy) become aware of themselves at the
# class level
@@ -1089,6 +1110,7 @@ def register_class(class_, extra_init=None, on_exception=None):
getattr(class_, key, None)
_init_class_state(class_)
+ class_._class_state.deferred_scalar_loader=deferred_scalar_loader
oldinit = None
doinit = False
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 52e39372d..486b7b6b6 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -15,6 +15,7 @@ ORM.
"""
from sqlalchemy import util, logging, exceptions
from sqlalchemy.sql import expression
+from itertools import chain
class_mapper = None
__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
@@ -505,7 +506,27 @@ def build_path(mapper, key, prev=None):
return prev + (mapper.base_mapper, key)
else:
return (mapper.base_mapper, key)
-
+
+def serialize_path(path):
+ if path is None:
+ return None
+
+ return [
+ (mapper.class_, mapper.entity_name, key)
+ for mapper, key in [(path[i], path[i+1]) for i in range(0, len(path)-1, 2)]
+ ]
+
+def deserialize_path(path):
+ if path is None:
+ return None
+
+ global class_mapper
+ if class_mapper is None:
+ from sqlalchemy.orm import class_mapper
+
+ return tuple(
+ chain(*[(class_mapper(cls, entity), key) for cls, entity, key in path])
+ )
class MapperOption(object):
"""Describe a modification to a Query."""
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 8c375ea39..db666a4f9 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -758,7 +758,7 @@ class Mapper(object):
def on_exception(class_, oldinit, instance, args, kwargs):
util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
- attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
+ attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
self._class_state = self.class_._class_state
_mapper_registry[self] = True
@@ -1358,42 +1358,22 @@ class Mapper(object):
instance._sa_session_id = context.session.hash_key
session_identity_map[identitykey] = instance
- if currentload or context.populate_existing or self.always_refresh or state.trigger:
+ if currentload or context.populate_existing or self.always_refresh:
if isnew:
state.runid = context.runid
- state.trigger = None
context.progress.add(state)
-
+
if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-
+
+ elif getattr(state, 'expired_attributes', None):
+ if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+ self.populate_instance(context, instance, row, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew)
+
if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
result.append(instance)
return instance
-
- def _deferred_inheritance_condition(self, base_mapper, needs_tables):
- def visit_binary(binary):
- leftcol = binary.left
- rightcol = binary.right
- if leftcol is None or rightcol is None:
- return
- if leftcol.table not in needs_tables:
- binary.left = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((leftcol, binary.left))
- elif rightcol not in needs_tables:
- binary.right = sql.bindparam(None, None, type_=binary.right.type)
- param_names.append((rightcol, binary.right))
-
- allconds = []
- param_names = []
-
- for mapper in self.iterate_to_root():
- if mapper is base_mapper:
- break
- allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-
- return sql.and_(*allconds), param_names
def translate_row(self, tomapper, row):
"""Translate the column keys of a row into a new or proxied
@@ -1451,7 +1431,10 @@ class Mapper(object):
populators = new_populators
else:
populators = existing_populators
-
+
+ if only_load_props:
+ populators = [p for p in populators if p[0] in only_load_props]
+
for (key, populator) in populators:
selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
@@ -1464,26 +1447,75 @@ class Mapper(object):
p(state.obj())
def _get_poly_select_loader(self, selectcontext, row):
- # 'select' or 'union'+col not present
+ """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+
+ this loading uses a second SELECT statement to load additional tables,
+ either immediately after loading the main table or via a deferred attribute trigger.
+ """
+
(hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
- if hosted_mapper is None or not needs_tables or hosted_mapper.polymorphic_fetch == 'deferred':
+
+ if hosted_mapper is None or not needs_tables:
return
cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
statement = sql.select(needs_tables, cond, use_labels=True)
- def post_execute(instance, **flags):
- if self.__should_log_debug:
- self.__log_debug("Post query loading instance " + instance_str(instance))
+
+ if hosted_mapper.polymorphic_fetch == 'select':
+ def post_execute(instance, **flags):
+ if self.__should_log_debug:
+ self.__log_debug("Post query loading instance " + instance_str(instance))
+
+ identitykey = self.identity_key_from_instance(instance)
+
+ params = {}
+ for c, bind in param_names:
+ params[bind] = self._get_attr_by_column(instance, c)
+ row = selectcontext.session.connection(self).execute(statement, params).fetchone()
+ self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+ return post_execute
+ elif hosted_mapper.polymorphic_fetch == 'deferred':
+ from sqlalchemy.orm.strategies import DeferredColumnLoader
+
+ def post_execute(instance, **flags):
+ def create_statement(instance):
+ params = {}
+ for (c, bind) in param_names:
+ # use the "committed" (database) version to get query column values
+ params[bind] = self._get_committed_attr_by_column(instance, c)
+ return (statement, params)
+
+ props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
+ keys = [p.key for p in props]
+ for prop in props:
+ strategy = prop._get_strategy(DeferredColumnLoader)
+ instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
+ return post_execute
+ else:
+ return None
+
+ def _deferred_inheritance_condition(self, base_mapper, needs_tables):
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+ if leftcol.table not in needs_tables:
+ binary.left = sql.bindparam(None, None, type_=binary.right.type)
+ param_names.append((leftcol, binary.left))
+ elif rightcol not in needs_tables:
+ binary.right = sql.bindparam(None, None, type_=binary.right.type)
+ param_names.append((rightcol, binary.right))
- identitykey = self.identity_key_from_instance(instance)
+ allconds = []
+ param_names = []
- params = {}
- for c, bind in param_names:
- params[bind] = self._get_attr_by_column(instance, c)
- row = selectcontext.session.connection(self).execute(statement, params).fetchone()
- self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+ for mapper in self.iterate_to_root():
+ if mapper is base_mapper:
+ break
+ allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
- return post_execute
+ return sql.and_(*allconds), param_names
Mapper.logger = logging.class_logger(Mapper)
@@ -1501,6 +1533,16 @@ def has_mapper(object):
return hasattr(object, '_entity_name')
+object_session = None
+
+def _load_scalar_attributes(instance, attribute_names):
+ global object_session
+ if not object_session:
+ from sqlalchemy.orm.session import object_session
+
+ if object_session(instance).query(object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+ raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+
def _state_mapper(state, entity_name=None):
return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 993eba4c5..f75d5c36c 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1113,7 +1113,7 @@ class Session(object):
return util.IdentitySet(self.uow.new.values())
new = property(new)
-
+
def _expire_state(state, attribute_names):
"""Standalone expire instance function.
@@ -1124,12 +1124,6 @@ def _expire_state(state, attribute_names):
If the list is None or blank, the entire instance is expired.
"""
- if state.trigger is None:
- def load_attributes(instance, attribute_names):
- if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
- raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
- state.trigger = load_attributes
-
state.expire_attributes(attribute_names)
register_attribute = unitofwork.register_attribute
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 60fc02579..33981f161 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -10,7 +10,7 @@ from sqlalchemy import sql, util, exceptions, logging
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors, expression, operators
from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
@@ -80,53 +80,13 @@ class ColumnLoader(LoaderStrategy):
if self._should_log_debug:
self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
return (new_execute, None, None)
-
- # our mapped column is not present in the row. check if we need to initialize a polymorphic
- # row fetcher used by inheritance.
- (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
-
- if hosted_mapper is None:
- return (None, None, None)
-
- if hosted_mapper.polymorphic_fetch == 'deferred':
- # 'deferred' polymorphic row fetcher, put a callable on the property.
- # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
- # the mapper for the object creates the WHERE criterion using the mapper who originally
- # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
- # and this mapper. (i.e. A->B->C, the query used mapper A. therefore will need B's and C's tables
- # in the query).
-
- # deferred loader strategy
- strategy = self.parent_property._get_strategy(DeferredColumnLoader)
-
- # full list of ColumnProperty objects to be loaded in the deferred fetch
- props = [p.key for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
-
- # TODO: we are somewhat duplicating efforts from mapper._get_poly_select_loader
- # and should look for ways to simplify.
- cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
- statement = sql.select(needs_tables, cond, use_labels=True)
- def create_statement(instance):
- params = {}
- for (c, bind) in param_names:
- # use the "committed" (database) version to get query column values
- params[bind] = mapper._get_committed_attr_by_column(instance, c)
- return (statement, params)
-
+ else:
def new_execute(instance, row, isnew, **flags):
if isnew:
- instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
-
+ instance._state.expire_attributes([self.key])
if self._should_log_debug:
- self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
-
+ self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
return (new_execute, None, None)
- else:
- # immediate polymorphic row fetcher. no processing needed for this row.
- if self._should_log_debug:
- self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
- return (None, None, None)
-
ColumnLoader.logger = logging.class_logger(ColumnLoader)
@@ -170,9 +130,10 @@ class DeferredColumnLoader(LoaderStrategy):
self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
def setup_loader(self, instance, props=None, create_statement=None):
- localparent = mapper.object_mapper(instance, raiseerror=False)
- if localparent is None:
+ if not mapper.has_mapper(instance):
return None
+
+ localparent = mapper.object_mapper(instance)
# adjust for the ColumnProperty associated with the instance
# not being our own ColumnProperty. This can occur when entity_name
@@ -181,39 +142,64 @@ class DeferredColumnLoader(LoaderStrategy):
prop = localparent.get_property(self.key)
if prop is not self.parent_property:
return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-
- def lazyload():
- if not mapper.has_identity(instance):
- return None
-
- if props is not None:
- group = props
- elif self.group is not None:
- group = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
- else:
- group = [self.parent_property.key]
-
- # narrow the keys down to just those which aren't present on the instance
- group = [k for k in group if k not in instance.__dict__]
-
- if self._should_log_debug:
- self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join(group) or 'None'))
-
- session = sessionlib.object_session(instance)
- if session is None:
- raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
- if create_statement is None:
- ident = instance._instance_key[1]
- session.query(localparent)._get(None, ident=ident, only_load_props=group, refresh_instance=instance._state)
- else:
- statement, params = create_statement(instance)
- session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=instance._state)
- return attributes.ATTR_WAS_SET
- return lazyload
+ return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
+class LoadDeferredColumns(object):
+ """callable, serializable loader object used by DeferredColumnLoader"""
+
+ def __init__(self, instance, key, keys, optimizing_statement):
+ self.instance = instance
+ self.key = key
+ self.keys = keys
+ self.optimizing_statement = optimizing_statement
+
+ def __getstate__(self):
+ return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+
+ def __setstate__(self, state):
+ self.instance = state['instance']
+ self.key = state['key']
+ self.keys = state['keys']
+ self.optimizing_statement = None
+
+ def __call__(self):
+ if not mapper.has_identity(self.instance):
+ return None
+
+ localparent = mapper.object_mapper(self.instance, raiseerror=False)
+
+ prop = localparent.get_property(self.key)
+ strategy = prop._get_strategy(DeferredColumnLoader)
+
+ if self.keys:
+ toload = self.keys
+ elif strategy.group:
+ toload = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==strategy.group]
+ else:
+ toload = [self.key]
+
+ # narrow the keys down to just those which have no history
+ group = [k for k in toload if k in self.instance._state.unmodified]
+
+ if strategy._should_log_debug:
+ strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+
+ session = sessionlib.object_session(self.instance)
+ if session is None:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+
+ query = session.query(localparent)
+ if not self.optimizing_statement:
+ ident = self.instance._instance_key[1]
+ query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
+ else:
+ statement, params = self.optimizing_statement(self.instance)
+ query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+ return attributes.ATTR_WAS_SET
+
class DeferredOption(StrategizedOption):
def __init__(self, key, defer=False):
super(DeferredOption, self).__init__(key)
@@ -276,7 +262,7 @@ NoLoader.logger = logging.class_logger(NoLoader)
class LazyLoader(AbstractRelationLoader):
def init(self):
super(LazyLoader, self).init()
- (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self)
+ (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self)
self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
@@ -293,10 +279,10 @@ class LazyLoader(AbstractRelationLoader):
def lazy_clause(self, instance, reverse_direction=False):
if instance is None:
- return self.lazy_none_clause(reverse_direction)
+ return self._lazy_none_clause(reverse_direction)
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+ (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
else:
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
@@ -308,9 +294,9 @@ class LazyLoader(AbstractRelationLoader):
bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
- def lazy_none_clause(self, reverse_direction=False):
+ def _lazy_none_clause(self, reverse_direction=False):
if not reverse_direction:
- (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+ (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
else:
(criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
@@ -331,71 +317,18 @@ class LazyLoader(AbstractRelationLoader):
def setup_loader(self, instance, options=None, path=None):
if not mapper.has_mapper(instance):
return None
- else:
- # adjust for the PropertyLoader associated with the instance
- # not being our own PropertyLoader. This can occur when entity_name
- # mappers are used to map different versions of the same PropertyLoader
- # to the class.
- prop = mapper.object_mapper(instance).get_property(self.key)
- if prop is not self.parent_property:
- return prop._get_strategy(LazyLoader).setup_loader(instance)
-
- def lazyload():
- if self._should_log_debug:
- self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
- if not mapper.has_identity(instance):
- return None
+ localparent = mapper.object_mapper(instance)
- session = sessionlib.object_session(instance)
- if session is None:
- try:
- session = mapper.object_mapper(instance).get_session()
- except exceptions.InvalidRequestError:
- raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
- # if we have a simple straight-primary key load, use mapper.get()
- # to possibly save a DB round trip
- q = session.query(self.mapper).autoflush(False)
- if path:
- q = q._with_current_path(path)
- if self.use_get:
- params = {}
- for col, bind in self.lazybinds.iteritems():
- # use the "committed" (database) version to get query column values
- params[bind.key] = self.parent._get_committed_attr_by_column(instance, col)
- ident = []
- nonnulls = False
- for primary_key in self.select_mapper.primary_key:
- bind = self.lazyreverse[primary_key]
- v = params[bind.key]
- if v is not None:
- nonnulls = True
- ident.append(v)
- if not nonnulls:
- return None
- if options:
- q = q._conditional_options(*options)
- return q.get(ident)
- elif self.order_by is not False:
- q = q.order_by(self.order_by)
- elif self.secondary is not None and self.secondary.default_order_by() is not None:
- q = q.order_by(self.secondary.default_order_by())
-
- if options:
- q = q._conditional_options(*options)
- q = q.filter(self.lazy_clause(instance))
-
- result = q.all()
- if self.uselist:
- return result
- else:
- if result:
- return result[0]
- else:
- return None
-
- return lazyload
+ # adjust for the PropertyLoader associated with the instance
+ # not being our own PropertyLoader. This can occur when entity_name
+ # mappers are used to map different versions of the same PropertyLoader
+ # to the class.
+ prop = localparent.get_property(self.key)
+ if prop is not self.parent_property:
+ return prop._get_strategy(LazyLoader).setup_loader(instance)
+
+ return LoadLazyAttribute(instance, self.key, options, path)
def create_row_processor(self, selectcontext, mapper, row):
if not self.is_class_level or len(selectcontext.options):
@@ -424,7 +357,7 @@ class LazyLoader(AbstractRelationLoader):
(primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
binds = {}
- reverse = {}
+ equated_columns = {}
def should_bind(targetcol, othercol):
if reverse_direction and not secondaryjoin:
@@ -437,20 +370,17 @@ class LazyLoader(AbstractRelationLoader):
return
leftcol = binary.left
rightcol = binary.right
-
+
+ equated_columns[rightcol] = leftcol
+ equated_columns[leftcol] = rightcol
+
if should_bind(leftcol, rightcol):
- col = leftcol
- binary.left = binds.setdefault(leftcol,
- sql.bindparam(None, None, type_=binary.right.type))
- reverse[rightcol] = binds[col]
+ binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
# the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
# which can happen in rare cases (test/orm/relationships.py RelationTest2)
if leftcol is not rightcol and should_bind(rightcol, leftcol):
- col = rightcol
- binary.right = binds.setdefault(rightcol,
- sql.bindparam(None, None, type_=binary.left.type))
- reverse[leftcol] = binds[col]
+ binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
lazywhere = primaryjoin
@@ -461,11 +391,86 @@ class LazyLoader(AbstractRelationLoader):
if reverse_direction:
secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
lazywhere = sql.and_(lazywhere, secondaryjoin)
- return (lazywhere, binds, reverse)
+ return (lazywhere, binds, equated_columns)
_create_lazy_clause = classmethod(_create_lazy_clause)
LazyLoader.logger = logging.class_logger(LazyLoader)
+class LoadLazyAttribute(object):
+ """callable, serializable loader object used by LazyLoader"""
+
+ def __init__(self, instance, key, options, path):
+ self.instance = instance
+ self.key = key
+ self.options = options
+ self.path = path
+
+ def __getstate__(self):
+ return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+
+ def __setstate__(self, state):
+ self.instance = state['instance']
+ self.key = state['key']
+ self.options= state['options']
+ self.path = deserialize_path(state['path'])
+
+ def __call__(self):
+ instance = self.instance
+
+ if not mapper.has_identity(instance):
+ return None
+
+ instance_mapper = mapper.object_mapper(instance)
+ prop = instance_mapper.get_property(self.key)
+ strategy = prop._get_strategy(LazyLoader)
+
+ if strategy._should_log_debug:
+ strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+
+ session = sessionlib.object_session(instance)
+ if session is None:
+ try:
+ session = instance_mapper.get_session()
+ except exceptions.InvalidRequestError:
+ raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+ q = session.query(prop.mapper).autoflush(False)
+ if self.path:
+ q = q._with_current_path(self.path)
+
+ # if we have a simple primary key load, use mapper.get()
+ # to possibly save a DB round trip
+ if strategy.use_get:
+ ident = []
+ allnulls = True
+ for primary_key in prop.select_mapper.primary_key:
+ val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+ allnulls = allnulls and val is None
+ ident.append(val)
+ if allnulls:
+ return None
+ if self.options:
+ q = q._conditional_options(*self.options)
+ return q.get(ident)
+
+ if strategy.order_by is not False:
+ q = q.order_by(strategy.order_by)
+ elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
+ q = q.order_by(strategy.secondary.default_order_by())
+
+ if self.options:
+ q = q._conditional_options(*self.options)
+ q = q.filter(strategy.lazy_clause(instance))
+
+ result = q.all()
+ if strategy.uselist:
+ return result
+ else:
+ if result:
+ return result[0]
+ else:
+ return None
+
class EagerLoader(AbstractRelationLoader):
"""Loads related objects inline with a parent query."""
@@ -630,8 +635,7 @@ class EagerLoader(AbstractRelationLoader):
if self._should_log_debug:
self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
-
-
+
def __str__(self):
return str(self.parent) + "." + self.key
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 7b76183be..397d99c0f 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -284,8 +284,10 @@ def instance_str(instance):
def state_str(state):
"""Return a string describing an instance."""
-
- return state.class_.__name__ + "@" + hex(id(state.obj()))
+ if state is None:
+ return "None"
+ else:
+ return state.class_.__name__ + "@" + hex(id(state.obj()))
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index ff370bc59..df3dbd279 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -47,7 +47,6 @@ __all__ = [
'subquery', 'table', 'text', 'union', 'union_all', 'update', ]
-BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""Return a descending ``ORDER BY`` clause element.
@@ -1795,6 +1794,8 @@ class _TextClause(ClauseElement):
__visit_name__ = 'textclause'
+ _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+
def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
self._bind = bind
self.bindparams = {}
@@ -1809,7 +1810,7 @@ class _TextClause(ClauseElement):
# scan the string and search for bind parameter names, add them
# to the list of bindparams
- self.text = BIND_PARAMS.sub(repl, text)
+ self.text = self._bind_params_regex.sub(repl, text)
if bindparams is not None:
for b in bindparams:
self.bindparams[b.key] = b