diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/sqlalchemy/orm/properties.py | 36 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/session.py | 6 |
2 files changed, 25 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index db43a8e27..4fd9a3e9b 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -38,7 +38,7 @@ class SynonymProperty(MapperProperty): return s return getattr(obj, self.name) setattr(self.parent.class_, self.key, SynonymProp()) - def merge(self, session, source, dest): + def merge(self, session, source, dest, _recursive): pass class ColumnProperty(StrategizedProperty): @@ -61,7 +61,7 @@ class ColumnProperty(StrategizedProperty): setattr(object, self.key, value) def get_history(self, obj, passive=False): return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive) - def merge(self, session, source, dest): + def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) def compare(self, value): return self.columns[0] == value @@ -127,20 +127,26 @@ class PropertyLoader(StrategizedProperty): def __str__(self): return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper) - def merge(self, session, source, dest): - if not "merge" in self.cascade: + def merge(self, session, source, dest, _recursive): + if not "merge" in self.cascade or source in _recursive: return - childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) - if childlist is None: - return - if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) - for current in list(childlist): - dest_list.append(session.merge(current)) - else: - setattr(dest, self.key, session.merge(current)) - + _recursive.add(source) + try: + childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) + if childlist is None: + return + if self.uselist: + # sets a blank list according to the correct list class + dest_list = getattr(self.parent.class_, self.key).initialize(dest) + for current in list(childlist): + dest_list.append(session.merge(current, _recursive=_recursive)) + else: + current = list(childlist)[0] + if current is not None: + setattr(dest, self.key, session.merge(current, _recursive=_recursive)) + finally: + _recursive.remove(source) + def cascade_iterator(self, type, object, recursive, halt_on=None): if not type in self.cascade: return diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 829220688..f2a718177 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -323,7 +323,7 @@ class Session(object): for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)): self.uow.register_deleted(c) - def merge(self, object, entity_name=None): + def merge(self, object, entity_name=None, _recursive=None): """copy the state of the given object onto the persistent object with the same identifier. If there is no persistent instance currently associated with the session, it will be loaded. @@ -331,6 +331,8 @@ class Session(object): a newly persistent instance. The given instance does not become associated with the session. This operation cascades to associated instances if the association is mapped with cascade="merge". """ + if _recursive is None: + _recursive = util.Set() mapper = _object_mapper(object) key = getattr(object, '_instance_key', None) if key is None: @@ -341,7 +343,7 @@ class Session(object): else: merged = self.get(mapper.class_, key[1]) for prop in mapper.props.values(): - prop.merge(self, object, merged) + prop.merge(self, object, merged, _recursive) if key is None: self.save(merged) return merged |
