summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/dependency.py32
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py15
-rw-r--r--lib/sqlalchemy/topological.py32
-rw-r--r--test/base/test_dependency.py41
4 files changed, 58 insertions, 62 deletions
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index a34c45233..dec90adfe 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -94,10 +94,11 @@ class DependencyProcessor(object):
"""
# locate and disable the aggregate processors
# for this dependency
- after_save = unitofwork.ProcessAll(uow, self, False, True)
+
before_delete = unitofwork.ProcessAll(uow, self, True, True)
- after_save.disabled = True
before_delete.disabled = True
+ after_save = unitofwork.ProcessAll(uow, self, False, True)
+ after_save.disabled = True
# check if the "child" side is part of the cycle
child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.base_mapper)
@@ -122,7 +123,7 @@ class DependencyProcessor(object):
# check if the "parent" side is part of the cycle
if not isdelete:
parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
- parent_deletes = before_delte = None
+ parent_deletes = before_delete = None
if parent_saves in uow.cycles:
parent_in_cycles = True
else:
@@ -133,19 +134,18 @@ class DependencyProcessor(object):
# now create actions /dependencies for each state.
for state in states:
+ # I'd like to emit the before_delete/after_save actions
+ # here and have the unit of work not get confused by that
+ # when it alters the list of dependencies...
if isdelete:
before_delete = unitofwork.ProcessState(uow, self, True, state)
- yield before_delete
+ if parent_in_cycles:
+ parent_deletes = unitofwork.DeleteState(uow, state)
else:
after_save = unitofwork.ProcessState(uow, self, False, state)
- yield after_save
-
- if parent_in_cycles:
- if isdelete:
- parent_deletes = unitofwork.DeleteState(uow, state)
- else:
+ if parent_in_cycles:
parent_saves = unitofwork.SaveUpdateState(uow, state)
-
+
if child_in_cycles:
# locate each child state associated with the parent action,
# create dependencies for each.
@@ -174,7 +174,11 @@ class DependencyProcessor(object):
child_action,
after_save, before_delete,
isdelete, childisdelete)
-
+
+ # ... but at the moment it
+ # does so we emit a null iterator
+ return iter([])
+
def presort_deletes(self, uowcommit, states):
pass
@@ -304,8 +308,8 @@ class OneToManyDP(DependencyProcessor):
])
else:
uow.dependencies.update([
- (child_action, before_delete),
- (before_delete, delete_parent),
+ (before_delete, child_action),
+ (child_action, delete_parent)
])
def presort_deletes(self, uowcommit, states):
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index b8373ff63..898be9139 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -195,6 +195,8 @@ class UOWTransaction(object):
self.dependencies.remove(edge)
elif cycles.issuperset(edge):
self.dependencies.remove(edge)
+ elif edge[0].disabled or edge[1].disabled:
+ self.dependencies.remove(edge)
elif edge[0] in cycles:
self.dependencies.remove(edge)
for dep in convert[edge[0]]:
@@ -203,19 +205,18 @@ class UOWTransaction(object):
self.dependencies.remove(edge)
for dep in convert[edge[1]]:
self.dependencies.add((edge[0], dep))
- elif edge[0].disabled or edge[1].disabled:
- self.dependencies.remove(edge)
postsort_actions = set(
- [a for a in self.postsort_actions.values()
- if not a.disabled]
+ [a for a in self.postsort_actions.values()
+ if not a.disabled
+ ]
).difference(cycles)
# execute actions
sort = topological.sort(self.dependencies, postsort_actions)
-# print "------------------------"
-# print self.dependencies
-# print sort
+ #print "------------------------"
+ #print self.dependencies
+ #print sort
for rec in sort:
rec.execute(self)
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index 5fc982ae0..bcf47bd64 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -26,26 +26,16 @@ __all__ = ['sort']
class _EdgeCollection(object):
"""A collection of directed edges."""
- def __init__(self):
+ def __init__(self, edges):
self.parent_to_children = util.defaultdict(set)
self.child_to_parents = util.defaultdict(set)
-
- def add(self, edge):
- """Add an edge to this collection."""
-
- parentnode, childnode = edge
- self.parent_to_children[parentnode].add(childnode)
- self.child_to_parents[childnode].add(parentnode)
-
+ for parentnode, childnode in edges:
+ self.parent_to_children[parentnode].add(childnode)
+ self.child_to_parents[childnode].add(parentnode)
+
def has_parents(self, node):
return node in self.child_to_parents and bool(self.child_to_parents[node])
- def edges_by_parent(self, node):
- if node in self.parent_to_children:
- return [(node, child) for child in self.parent_to_children[node]]
- else:
- return []
-
def outgoing(self, node):
"""an iterable returning all nodes reached via node's outgoing edges"""
@@ -79,13 +69,9 @@ def sort(tuples, allitems):
'tuples' is a list of tuples representing a partial ordering.
"""
- edges = _EdgeCollection()
+ edges = _EdgeCollection(tuples)
nodes = set(allitems)
- for t in tuples:
- nodes.update(t)
- edges.add(t)
-
queue = []
for n in nodes:
if not edges.has_parents(n):
@@ -106,12 +92,8 @@ def sort(tuples, allitems):
def find_cycles(tuples, allitems):
# straight from gvr with some mods
todo = set(allitems)
- edges = _EdgeCollection()
+ edges = _EdgeCollection(tuples)
- for t in tuples:
- todo.update(t)
- edges.add(t)
-
output = set()
while todo:
diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py
index 8c38a98b0..462e923f1 100644
--- a/test/base/test_dependency.py
+++ b/test/base/test_dependency.py
@@ -5,7 +5,14 @@ from sqlalchemy import exc
import collections
class DependencySortTest(TestBase):
- def assert_sort(self, tuples, result):
+ def assert_sort(self, tuples, allitems=None):
+
+ if allitems is None:
+ allitems = self._nodes_from_tuples(tuples)
+ else:
+ allitems = self._nodes_from_tuples(tuples).union(allitems)
+
+ result = topological.sort(tuples, allitems)
deps = collections.defaultdict(set)
for parent, child in tuples:
@@ -16,6 +23,12 @@ class DependencySortTest(TestBase):
for n in result[i:]:
assert node not in deps[n]
+ def _nodes_from_tuples(self, tups):
+ s = set()
+ for tup in tups:
+ s.update(tup)
+ return s
+
def test_sort_one(self):
rootnode = 'root'
node2 = 'node2'
@@ -36,7 +49,7 @@ class DependencySortTest(TestBase):
(node4, subnode3),
(node4, subnode4)
]
- self.assert_sort(tuples, topological.sort(tuples, []))
+ self.assert_sort(tuples)
def test_sort_two(self):
node1 = 'node1'
@@ -53,7 +66,7 @@ class DependencySortTest(TestBase):
(node5, node6),
(node6, node2)
]
- self.assert_sort(tuples, topological.sort(tuples, [node7]))
+ self.assert_sort(tuples, [node7])
def test_sort_three(self):
node1 = 'keywords'
@@ -66,7 +79,7 @@ class DependencySortTest(TestBase):
(node1, node3),
(node3, node2)
]
- self.assert_sort(tuples, topological.sort(tuples, []))
+ self.assert_sort(tuples)
def test_raise_on_cycle_one(self):
node1 = 'node1'
@@ -82,7 +95,7 @@ class DependencySortTest(TestBase):
(node3, node1),
(node4, node1)
]
- allitems = [node1, node2, node3, node4]
+ allitems = self._nodes_from_tuples(tuples)
assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
# TODO: test find_cycles
@@ -101,7 +114,8 @@ class DependencySortTest(TestBase):
(node3, node2),
(node2, node3)
]
- assert_raises(exc.CircularDependencyError, topological.sort, tuples, [])
+ allitems = self._nodes_from_tuples(tuples)
+ assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
# TODO: test find_cycles
@@ -112,24 +126,19 @@ class DependencySortTest(TestBase):
(question, provider), (providerservice, question),
(provider, providerservice), (question, answer), (issue, question)]
- assert_raises(exc.CircularDependencyError, topological.sort, tuples, [])
+ allitems = self._nodes_from_tuples(tuples)
+ assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems)
# TODO: test find_cycles
def test_large_sort(self):
tuples = [(i, i + 1) for i in range(0, 1500, 2)]
- self.assert_sort(
- tuples,
- topological.sort(tuples, [])
- )
+ self.assert_sort(tuples)
def test_ticket_1380(self):
# ticket:1380 regression: would raise a KeyError
tuples = [(id(i), i) for i in range(3)]
- self.assert_sort(
- tuples,
- topological.sort(tuples, [])
- )
+ self.assert_sort(tuples)
def test_find_cycle_one(self):
node1 = 'node1'
@@ -145,7 +154,7 @@ class DependencySortTest(TestBase):
]
eq_(
- topological.find_cycles(tuples),
+ topological.find_cycles(tuples, self._nodes_from_tuples(tuples)),
set([node1, node2, node3])
)