summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/ansible/inventory/group.py89
-rw-r--r--lib/ansible/inventory/host.py13
-rw-r--r--test/units/plugins/inventory/test_group.py125
3 files changed, 200 insertions, 27 deletions
diff --git a/lib/ansible/inventory/group.py b/lib/ansible/inventory/group.py
index 4847e6fbd8..859781cc4d 100644
--- a/lib/ansible/inventory/group.py
+++ b/lib/ansible/inventory/group.py
@@ -19,6 +19,8 @@ __metaclass__ = type
from ansible.errors import AnsibleError
+from itertools import chain
+
class Group:
''' a group of ansible hosts '''
@@ -80,6 +82,38 @@ class Group:
g.deserialize(parent_data)
self.parent_groups.append(g)
+ def _walk_relationship(self, rel):
+ '''
+ Given `rel` that is an iterable property of Group,
+ consitituting a directed acyclic graph among all groups,
+ Returns a set of all groups in full tree
+ A B C
+ | / | /
+ | / | /
+ D -> E
+ | / vertical connections
+ | / are directed upward
+ F
+ Called on F, returns set of (A, B, C, D, E)
+ '''
+ seen = set([])
+ unprocessed = set(getattr(self, rel))
+
+ while unprocessed:
+ seen.update(unprocessed)
+ unprocessed = set(chain.from_iterable(
+ getattr(g, rel) for g in unprocessed
+ ))
+ unprocessed.difference_update(seen)
+
+ return seen
+
+ def get_ancestors(self):
+ return self._walk_relationship('parent_groups')
+
+ def get_descendants(self):
+ return self._walk_relationship('child_groups')
+
@property
def host_names(self):
if self._hosts is None:
@@ -96,6 +130,17 @@ class Group:
# don't add if it's already there
if group not in self.child_groups:
+
+ # prepare list of group's new ancestors this edge creates
+ start_ancestors = group.get_ancestors()
+ new_ancestors = self.get_ancestors()
+ if group in new_ancestors:
+ raise AnsibleError(
+ "Adding group '%s' as child to '%s' creates a recursive "
+ "dependency loop." % (group.name, self.name))
+ new_ancestors.add(self)
+ new_ancestors.difference_update(start_ancestors)
+
self.child_groups.append(group)
# update the depth of the child
@@ -109,18 +154,28 @@ class Group:
if self.name not in [g.name for g in group.parent_groups]:
group.parent_groups.append(self)
for h in group.get_hosts():
- h.populate_ancestors()
+ h.populate_ancestors(additions=new_ancestors)
self.clear_hosts_cache()
def _check_children_depth(self):
- try:
- for group in self.child_groups:
- group.depth = max([self.depth + 1, group.depth])
- group._check_children_depth()
- except RuntimeError:
- raise AnsibleError("The group named '%s' has a recursive dependency loop." % self.name)
+ depth = self.depth
+ start_depth = self.depth # self.depth could change over loop
+ seen = set([])
+ unprocessed = set(self.child_groups)
+
+ while unprocessed:
+ seen.update(unprocessed)
+ depth += 1
+ to_process = unprocessed.copy()
+ unprocessed = set([])
+ for g in to_process:
+ if g.depth < depth:
+ g.depth = depth
+ unprocessed.update(g.child_groups)
+ if depth - start_depth > len(seen):
+ raise AnsibleError("The group named '%s' has a recursive dependency loop." % self.name)
def add_host(self, host):
if host.name not in self.host_names:
@@ -147,8 +202,8 @@ class Group:
def clear_hosts_cache(self):
self._hosts_cache = None
- for g in self.parent_groups:
- g.clear_hosts_cache()
+ for g in self.get_ancestors():
+ g._hosts_cache = None
def get_hosts(self):
@@ -160,8 +215,8 @@ class Group:
hosts = []
seen = {}
- for kid in self.child_groups:
- kid_hosts = kid.get_hosts()
+ for kid in self.get_descendants():
+ kid_hosts = kid.hosts
for kk in kid_hosts:
if kk not in seen:
seen[kk] = 1
@@ -179,18 +234,6 @@ class Group:
def get_vars(self):
return self.vars.copy()
- def _get_ancestors(self):
-
- results = {}
- for g in self.parent_groups:
- results[g.name] = g
- results.update(g._get_ancestors())
- return results
-
- def get_ancestors(self):
-
- return self._get_ancestors().values()
-
def set_priority(self, priority):
try:
self.priority = int(priority)
diff --git a/lib/ansible/inventory/host.py b/lib/ansible/inventory/host.py
index 647e00dc1c..4327273945 100644
--- a/lib/ansible/inventory/host.py
+++ b/lib/ansible/inventory/host.py
@@ -101,17 +101,22 @@ class Host:
def get_name(self):
return self.name
- def populate_ancestors(self):
+ def populate_ancestors(self, additions=None):
# populate ancestors
- for group in self.groups:
- self.add_group(group)
+ if additions is None:
+ for group in self.groups:
+ self.add_group(group)
+ else:
+ for group in additions:
+ if group not in self.groups:
+ self.groups.append(group)
def add_group(self, group):
# populate ancestors first
for oldg in group.get_ancestors():
if oldg not in self.groups:
- self.add_group(oldg)
+ self.groups.append(oldg)
# actually add group
if group not in self.groups:
diff --git a/test/units/plugins/inventory/test_group.py b/test/units/plugins/inventory/test_group.py
new file mode 100644
index 0000000000..086c7cf798
--- /dev/null
+++ b/test/units/plugins/inventory/test_group.py
@@ -0,0 +1,125 @@
+# Copyright 2018 Alan Rominger <arominge@redhat.com>
+#
+# This file is part of Ansible
+#
+# Ansible is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Ansible is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Ansible. If not, see <http://www.gnu.org/licenses/>.
+
+from ansible.compat.tests import unittest
+
+from ansible.inventory.group import Group
+from ansible.inventory.host import Host
+from ansible.errors import AnsibleError
+
+
+class TestGroup(unittest.TestCase):
+
+ def test_depth_update(self):
+ A = Group('A')
+ B = Group('B')
+ Z = Group('Z')
+ A.add_child_group(B)
+ A.add_child_group(Z)
+ self.assertEqual(A.depth, 0)
+ self.assertEqual(Z.depth, 1)
+ self.assertEqual(B.depth, 1)
+
+ def test_depth_update_dual_branches(self):
+ alpha = Group('alpha')
+ A = Group('A')
+ alpha.add_child_group(A)
+ B = Group('B')
+ A.add_child_group(B)
+ Z = Group('Z')
+ alpha.add_child_group(Z)
+ beta = Group('beta')
+ B.add_child_group(beta)
+ Z.add_child_group(beta)
+
+ self.assertEqual(alpha.depth, 0) # apex
+ self.assertEqual(beta.depth, 3) # alpha -> A -> B -> beta
+
+ omega = Group('omega')
+ omega.add_child_group(alpha)
+
+ # verify that both paths are traversed to get the max depth value
+ self.assertEqual(B.depth, 3) # omega -> alpha -> A -> B
+ self.assertEqual(beta.depth, 4) # B -> beta
+
+ def test_depth_recursion(self):
+ A = Group('A')
+ B = Group('B')
+ A.add_child_group(B)
+ # hypothetical of adding B as child group to A
+ A.parent_groups.append(B)
+ B.child_groups.append(A)
+ # can't update depths of groups, because of loop
+ with self.assertRaises(AnsibleError):
+ B._check_children_depth()
+
+ def test_loop_detection(self):
+ A = Group('A')
+ B = Group('B')
+ C = Group('C')
+ A.add_child_group(B)
+ B.add_child_group(C)
+ with self.assertRaises(AnsibleError):
+ C.add_child_group(A)
+
+ def test_populates_descendant_hosts(self):
+ A = Group('A')
+ B = Group('B')
+ C = Group('C')
+ h = Host('h')
+ C.add_host(h)
+ A.add_child_group(B) # B is child of A
+ B.add_child_group(C) # C is descendant of A
+ A.add_child_group(B)
+ self.assertEqual(set(h.groups), set([C, B, A]))
+ h2 = Host('h2')
+ C.add_host(h2)
+ self.assertEqual(set(h2.groups), set([C, B, A]))
+
+ def test_ancestor_example(self):
+ # see docstring for Group._walk_relationship
+ groups = {}
+ for name in ['A', 'B', 'C', 'D', 'E', 'F']:
+ groups[name] = Group(name)
+ # first row
+ groups['A'].add_child_group(groups['D'])
+ groups['B'].add_child_group(groups['D'])
+ groups['B'].add_child_group(groups['E'])
+ groups['C'].add_child_group(groups['D'])
+ # second row
+ groups['D'].add_child_group(groups['E'])
+ groups['D'].add_child_group(groups['F'])
+ groups['E'].add_child_group(groups['F'])
+
+ self.assertEqual(
+ set(groups['F'].get_ancestors()),
+ set([
+ groups['A'], groups['B'], groups['C'], groups['D'], groups['E']
+ ])
+ )
+
+ def test_ancestors_recursive_loop_safe(self):
+ '''
+ The get_ancestors method may be referenced before circular parenting
+ checks, so the method is expected to be stable even with loops
+ '''
+ A = Group('A')
+ B = Group('B')
+ A.parent_groups.append(B)
+ B.parent_groups.append(A)
+ # finishes in finite time
+ self.assertEqual(A.get_ancestors(), set([A, B]))