summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-01-21 19:47:25 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-01-21 19:47:25 +0000
commit5f33323bc639c612e79bf2f91d4e2e7c28cfbaa8 (patch)
tree0745b57b34f3d3931b55151d6e87925f8f1c51de /lib
parentdd694c44f48ee544a53851f619ed131b05ff25fd (diff)
downloadsqlalchemy-5f33323bc639c612e79bf2f91d4e2e7c28cfbaa8.tar.gz
added recursion check to merge
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/properties.py36
-rw-r--r--lib/sqlalchemy/orm/session.py6
2 files changed, 25 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index db43a8e27..4fd9a3e9b 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -38,7 +38,7 @@ class SynonymProperty(MapperProperty):
return s
return getattr(obj, self.name)
setattr(self.parent.class_, self.key, SynonymProp())
- def merge(self, session, source, dest):
+ def merge(self, session, source, dest, _recursive):
pass
class ColumnProperty(StrategizedProperty):
@@ -61,7 +61,7 @@ class ColumnProperty(StrategizedProperty):
setattr(object, self.key, value)
def get_history(self, obj, passive=False):
return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive)
- def merge(self, session, source, dest):
+ def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
def compare(self, value):
return self.columns[0] == value
@@ -127,20 +127,26 @@ class PropertyLoader(StrategizedProperty):
def __str__(self):
return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper)
- def merge(self, session, source, dest):
- if not "merge" in self.cascade:
+ def merge(self, session, source, dest, _recursive):
+ if not "merge" in self.cascade or source in _recursive:
return
- childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
- if childlist is None:
- return
- if self.uselist:
- # sets a blank list according to the correct list class
- dest_list = getattr(self.parent.class_, self.key).initialize(dest)
- for current in list(childlist):
- dest_list.append(session.merge(current))
- else:
- setattr(dest, self.key, session.merge(current))
-
+ _recursive.add(source)
+ try:
+ childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
+ if childlist is None:
+ return
+ if self.uselist:
+ # sets a blank list according to the correct list class
+ dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+ for current in list(childlist):
+ dest_list.append(session.merge(current, _recursive=_recursive))
+ else:
+ current = list(childlist)[0]
+ if current is not None:
+ setattr(dest, self.key, session.merge(current, _recursive=_recursive))
+ finally:
+ _recursive.remove(source)
+
def cascade_iterator(self, type, object, recursive, halt_on=None):
if not type in self.cascade:
return
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 829220688..f2a718177 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -323,7 +323,7 @@ class Session(object):
for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
- def merge(self, object, entity_name=None):
+ def merge(self, object, entity_name=None, _recursive=None):
"""copy the state of the given object onto the persistent object with the same identifier.
If there is no persistent instance currently associated with the session, it will be loaded.
@@ -331,6 +331,8 @@ class Session(object):
a newly persistent instance. The given instance does not become associated with the session.
This operation cascades to associated instances if the association is mapped with cascade="merge".
"""
+ if _recursive is None:
+ _recursive = util.Set()
mapper = _object_mapper(object)
key = getattr(object, '_instance_key', None)
if key is None:
@@ -341,7 +343,7 @@ class Session(object):
else:
merged = self.get(mapper.class_, key[1])
for prop in mapper.props.values():
- prop.merge(self, object, merged)
+ prop.merge(self, object, merged, _recursive)
if key is None:
self.save(merged)
return merged