summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 02:31:53 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-06-05 02:31:53 +0000
commitf8314ef9ff08af5f104731de402d6e6bd8c043f3 (patch)
tree22520019a71251e76b6f64af9709c92ee3bd0f03 /lib
parent7e5b3d2f9fd69924ac2cf588e60508975588aa28 (diff)
downloadsqlalchemy-f8314ef9ff08af5f104731de402d6e6bd8c043f3.tar.gz
improvements/fixes to session cascade iteration,
fixes to entity_name propigation
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/mapper.py36
-rw-r--r--lib/sqlalchemy/orm/properties.py49
-rw-r--r--lib/sqlalchemy/orm/session.py49
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py18
4 files changed, 98 insertions, 54 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 4e80aceeb..eba220384 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -387,25 +387,31 @@ class Mapper(object):
and assocites this Mapper with its class via the mapper_registry."""
oldinit = self.class_.__init__
def init(self, *args, **kwargs):
- self._entity_name = kwargs.pop('_sa_entity_name', None)
# this gets the AttributeManager to do some pre-initialization,
# in order to save on KeyErrors later on
sessionlib.global_attributes.init_attr(self)
+ entity_name = kwargs.pop('_sa_entity_name', None)
if kwargs.has_key('_sa_session'):
session = kwargs.pop('_sa_session')
else:
# works for whatever mapper the class is associated with
- mapper = mapper_registry.get(ClassKey(self.__class__, self._entity_name))
+ mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
if mapper is not None:
session = mapper.extension.get_session()
if session is EXT_PASS:
session = None
else:
session = None
+ # if a session was found, either via _sa_session or via mapper extension,
+ # save() this instance to the session, and give it an associated entity_name.
+ # otherwise, this instance will not have a session or mapper association until it is
+ # save()d to some session.
if session is not None:
+ self._entity_name = entity_name
session._register_new(self)
+
if oldinit is not None:
oldinit(self, *args, **kwargs)
# override oldinit, insuring that its not already a Mapper-decorated init method
@@ -748,16 +754,19 @@ class Mapper(object):
for prop in self.props.values():
prop.register_dependencies(uowcommit, *args, **kwargs)
- def cascade_iterator(self, type, object, recursive=None):
+ def cascade_iterator(self, type, object, callable_=None, recursive=None):
if recursive is None:
recursive=sets.Set()
- if object not in recursive:
- recursive.add(object)
- yield object
for prop in self.props.values():
for c in prop.cascade_iterator(type, object, recursive):
yield c
+ def cascade_callable(self, type, object, callable_, recursive=None):
+ if recursive is None:
+ recursive=sets.Set()
+ for prop in self.props.values():
+ prop.cascade_callable(type, object, callable_, recursive)
+
def _row_identity_key(self, row):
return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name)
@@ -929,6 +938,8 @@ class MapperProperty(object):
raise NotImplementedError()
def cascade_iterator(self, type, object, recursive=None):
return []
+ def cascade_callable(self, type, object, callable_, recursive=None):
+ return []
def copy(self):
raise NotImplementedError()
def get_criterion(self, query, key, value):
@@ -1157,13 +1168,16 @@ def hash_key(obj):
return obj.hash_key()
else:
return repr(obj)
+
+def has_mapper(object):
+ """returns True if the given object has a mapper association"""
+ return hasattr(object, '_entity_name')
-def object_mapper(object, raiseerror=True, entity_name=None):
- """given an object, returns the primary Mapper associated with the object
- or the object's class."""
+def object_mapper(object, raiseerror=True):
+ """given an object, returns the primary Mapper associated with the object instance"""
try:
- return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))]
- except KeyError:
+ return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))]
+ except (KeyError, AttributeError):
if raiseerror:
raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None)))
else:
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 8c38b897f..8cdaf3940 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -154,14 +154,25 @@ class PropertyLoader(mapper.MapperProperty):
if not type in self.cascade:
return
childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
-
+
+ mapper = self.mapper.primary_mapper()
for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
- if c is not None:
- if c not in recursive:
- recursive.add(c)
- yield c
- for c2 in self.mapper.primary_mapper().cascade_iterator(type, c, recursive):
- yield c2
+ if c is not None and c not in recursive:
+ recursive.add(c)
+ yield c
+ for c2 in mapper.cascade_iterator(type, c, recursive):
+ yield c2
+
+ def cascade_callable(self, type, object, callable_, recursive):
+ if not type in self.cascade:
+ return
+ childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
+ mapper = self.mapper.primary_mapper()
+ for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
+ if c is not None and c not in recursive:
+ recursive.add(c)
+ callable_(c, mapper.entity_name)
+ mapper.cascade_callable(type, c, callable_, recursive)
def copy(self):
x = self.__class__.__new__(self.__class__)
@@ -237,10 +248,16 @@ class PropertyLoader(mapper.MapperProperty):
raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (key, parent.class_.__name__, parent.class_.__name__))
self.do_init_subclass(key, parent)
+
+ def _register_attribute(self, class_, callable_=None):
+ sessionlib.global_attributes.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, callable_=callable_)
+
+ def _create_history(self, instance, callable_=None):
+ return sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_)
def _set_class_attribute(self, class_, key):
"""sets attribute behavior on our target class."""
- sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True)
+ self._register_attribute(class_)
def _get_direction(self):
"""determines our 'direction', i.e. do we represent one to many, many to many, etc."""
@@ -295,7 +312,7 @@ class PropertyLoader(mapper.MapperProperty):
if self.is_primary():
return
#print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key
- sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+ self._create_history(instance)
def register_dependencies(self, uowcommit):
self._dependency_processor.register_dependencies(uowcommit)
@@ -327,11 +344,17 @@ class LazyLoader(PropertyLoader):
def _set_class_attribute(self, class_, key):
# establish a class-level lazy loader on our class
#print "SETCLASSATTR LAZY", repr(class_), key
- sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True)
+ self._register_attribute(class_, callable_=lambda i: self.setup_loader(i))
def setup_loader(self, instance):
+ # make sure our parent mapper is the one thats assigned to this instance, else call that one
if not self.localparent.is_assigned(instance):
- return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
+ # if no mapper association with this instance (i.e. not in a session, not loaded by a mapper),
+ # then we cant set up a lazy loader
+ if not mapper.has_mapper(instance):
+ return None
+ else:
+ return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
def lazyload():
params = {}
allparams = True
@@ -379,7 +402,7 @@ class LazyLoader(PropertyLoader):
#print "EXEC NON-PRIAMRY", repr(self.mapper.class_), self.key
# we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
# which will override the class-level behavior
- sessionlib.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance), cascade=self.cascade, trackparent=True)
+ self._create_history(instance, callable_=self.setup_loader(instance))
else:
#print "EXEC PRIMARY", repr(self.mapper.class_), self.key
# we are the primary manager for this attribute on this class - reset its per-instance attribute state,
@@ -548,7 +571,7 @@ class EagerLoader(LazyLoader):
if isnew:
# new row loaded from the database. initialize a blank container on the instance.
# this will override any per-class lazyloading type of stuff.
- h = sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+ h = self._create_history(instance)
if not self.uselist:
if isnew:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 56d699cc6..bd1750165 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -275,12 +275,8 @@ class Session(object):
The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
instance.
"""
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- if c is object:
- self._save_impl(c, entity_name=entity_name)
- else:
- # TODO: this is running the cascade rules twice
- self.save_or_update(c, entity_name=entity_name)
+ self._save_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def update(self, object, entity_name=None):
"""Brings the given detached (saved) instance into this Session.
@@ -288,30 +284,31 @@ class Session(object):
Session), an exception is thrown.
This operation cascades the "save_or_update" method to associated instances if the relation is mapped
with cascade="save-update"."""
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- if c is object:
- self._update_impl(c, entity_name=entity_name)
- else:
- self.save_or_update(c, entity_name=entity_name)
+ self._update_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def save_or_update(self, object, entity_name=None):
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- key = getattr(object, '_instance_key', None)
- if key is None:
- self._save_impl(c, entity_name=entity_name)
- else:
- self._update_impl(c, entity_name=entity_name)
-
+ self._save_or_update_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+
+ def _save_or_update_impl(self, object, entity_name=None):
+ key = getattr(object, '_instance_key', None)
+ if key is None:
+ self._save_impl(object, entity_name=entity_name)
+ else:
+ self._update_impl(object, entity_name=entity_name)
+
def delete(self, object, entity_name=None):
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('delete', object):
+ #self.uow.register_deleted(object)
+ for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
def merge(self, object, entity_name=None):
instance = None
- for obj in object_mapper(object, entity_name=entity_name).cascade_iterator('merge', object):
+ for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)):
key = getattr(obj, '_instance_key', None)
if key is None:
- mapper = object_mapper(object, entity_name=entity_name)
+ mapper = object_mapper(object)
ident = mapper.identity(object)
for k in ident:
if k is None:
@@ -333,10 +330,8 @@ class Session(object):
if not self.uow.has_key(object._instance_key):
raise exceptions.InvalidRequestError("Instance '%s' is already persistent in a different Session" % repr(object))
else:
- entity_name = kwargs.get('entity_name', None)
- if entity_name is not None:
- m = class_mapper(object.__class__, entity_name=entity_name)
- m._assign_entity_name(object)
+ m = class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None))
+ m._assign_entity_name(object)
self._register_new(object)
def _update_impl(self, object, **kwargs):
@@ -422,8 +417,8 @@ def get_id_key(ident, class_, entity_name=None):
def get_row_key(row, class_, primary_key, entity_name=None):
return Session.get_row_key(row, class_, primary_key, entity_name)
-def object_mapper(obj, **kwargs):
- return sqlalchemy.orm.object_mapper(obj, **kwargs)
+def object_mapper(obj):
+ return sqlalchemy.orm.object_mapper(obj)
def class_mapper(class_, **kwargs):
return sqlalchemy.orm.class_mapper(class_, **kwargs)
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index c33f344fb..9e9778cad 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -49,7 +49,13 @@ class UOWListElement(attributes.ListAttribute):
if self.cascade is not None:
if not isdelete:
if self.cascade.save_update:
- sess.save_or_update(item)
+ # cascade the save_update operation onto the child object,
+ # relative to the mapper handling the parent object
+ # TODO: easier way to do this ?
+ mapper = object_mapper(obj)
+ prop = mapper.props[self.key]
+ ename = prop.mapper.entity_name
+ sess.save_or_update(item, entity_name=ename)
def append(self, item, _mapper_nohistory = False):
if _mapper_nohistory:
self.append_nohistory(item)
@@ -67,13 +73,19 @@ class UOWScalarElement(attributes.ScalarAttribute):
sess._register_changed(obj)
if newvalue is not None and self.cascade is not None:
if self.cascade.save_update:
- sess.save_or_update(newvalue)
+ # cascade the save_update operation onto the child object,
+ # relative to the mapper handling the parent object
+ # TODO: easier way to do this ?
+ mapper = object_mapper(obj)
+ prop = mapper.props[self.key]
+ ename = prop.mapper.entity_name
+ sess.save_or_update(newvalue, entity_name=ename)
class UOWAttributeManager(attributes.AttributeManager):
"""overrides AttributeManager to provide unit-of-work "dirty" hooks when scalar attribues are modified, plus factory methods for UOWProperrty/UOWListElement."""
def __init__(self):
attributes.AttributeManager.__init__(self)
-
+
def create_prop(self, class_, key, uselist, callable_, **kwargs):
return UOWProperty(class_, self, key, uselist, callable_, **kwargs)