summaryrefslogtreecommitdiff
path: root/test/lib/ansible_test/_internal/import_analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/ansible_test/_internal/import_analysis.py')
-rw-r--r--test/lib/ansible_test/_internal/import_analysis.py62
1 files changed, 49 insertions, 13 deletions
diff --git a/test/lib/ansible_test/_internal/import_analysis.py b/test/lib/ansible_test/_internal/import_analysis.py
index daefe33eda..d115cffa7f 100644
--- a/test/lib/ansible_test/_internal/import_analysis.py
+++ b/test/lib/ansible_test/_internal/import_analysis.py
@@ -65,10 +65,10 @@ def get_python_module_utils_imports(compile_targets):
for result in matches:
results.add(result)
- import_path = os.path.join('lib/', '%s.py' % import_name.replace('.', '/'))
+ import_path = get_import_path(import_name)
if import_path not in imports_by_target_path:
- import_path = os.path.join('lib/', import_name.replace('.', '/'), '__init__.py')
+ import_path = get_import_path(import_name, package=True)
if import_path not in imports_by_target_path:
raise ApplicationError('Cannot determine path for module_utils import: %s' % import_name)
@@ -127,7 +127,7 @@ def get_python_module_utils_name(path): # type: (str) -> str
base_path = data_context().content.module_utils_path
if data_context().content.collection:
- prefix = 'ansible_collections.' + data_context().content.collection.prefix
+ prefix = 'ansible_collections.' + data_context().content.collection.prefix + 'plugins.module_utils.'
else:
prefix = 'ansible.module_utils.'
@@ -183,6 +183,23 @@ def extract_python_module_utils_imports(path, module_utils):
return finder.imports
+def get_import_path(name, package=False): # type: (str, bool) -> str
+ """Return a path from an import name."""
+ if package:
+ filename = os.path.join(name.replace('.', '/'), '__init__.py')
+ else:
+ filename = '%s.py' % name.replace('.', '/')
+
+ if name.startswith('ansible.module_utils.'):
+ path = os.path.join('lib', filename)
+ elif data_context().content.collection and name.startswith('ansible_collections.%s.plugins.module_utils.' % data_context().content.collection.full_name):
+ path = '/'.join(filename.split('/')[3:])
+ else:
+ raise Exception('Unexpected import name: %s' % name)
+
+ return path
+
+
class ModuleUtilFinder(ast.NodeVisitor):
"""AST visitor to find valid module_utils imports."""
def __init__(self, path, module_utils):
@@ -213,10 +230,9 @@ class ModuleUtilFinder(ast.NodeVisitor):
"""
self.generic_visit(node)
- for alias in node.names:
- if alias.name.startswith('ansible.module_utils.'):
- # import ansible.module_utils.MODULE[.MODULE]
- self.add_import(alias.name, node.lineno)
+ # import ansible.module_utils.MODULE[.MODULE]
+ # import ansible_collections.{ns}.{col}.plugins.module_utils.module_utils.MODULE[.MODULE]
+ self.add_imports([alias.name for alias in node.names], node.lineno)
# noinspection PyPep8Naming
# pylint: disable=locally-disabled, invalid-name
@@ -229,11 +245,14 @@ class ModuleUtilFinder(ast.NodeVisitor):
if not node.module:
return
- if node.module == 'ansible.module_utils' or node.module.startswith('ansible.module_utils.'):
- for alias in node.names:
- # from ansible.module_utils import MODULE[, MODULE]
- # from ansible.module_utils.MODULE[.MODULE] import MODULE[, MODULE]
- self.add_import('%s.%s' % (node.module, alias.name), node.lineno)
+ if not node.module.startswith('ansible'):
+ return
+
+ # from ansible.module_utils import MODULE[, MODULE]
+ # from ansible.module_utils.MODULE[.MODULE] import MODULE[, MODULE]
+ # from ansible_collections.{ns}.{col}.plugins.module_utils import MODULE[, MODULE]
+ # from ansible_collections.{ns}.{col}.plugins.module_utils.MODULE[.MODULE] import MODULE[, MODULE]
+ self.add_imports(['%s.%s' % (node.module, alias.name) for alias in node.names], node.lineno)
def add_import(self, name, line_number):
"""
@@ -242,7 +261,7 @@ class ModuleUtilFinder(ast.NodeVisitor):
"""
import_name = name
- while len(name) > len('ansible.module_utils.'):
+ while self.is_module_util_name(name):
if name in self.module_utils:
if name not in self.imports:
display.info('%s:%d imports module_utils: %s' % (self.path, line_number, name), verbosity=5)
@@ -258,3 +277,20 @@ class ModuleUtilFinder(ast.NodeVisitor):
# Treat this error as a warning so tests can be executed as best as possible.
# This error should be detected by unit or integration tests.
display.warning('%s:%d Invalid module_utils import: %s' % (self.path, line_number, import_name))
+
+ def add_imports(self, names, line_no): # type: (t.List[str], int) -> None
+ """Add the given import names if they are module_utils imports."""
+ for name in names:
+ if self.is_module_util_name(name):
+ self.add_import(name, line_no)
+
+ @staticmethod
+ def is_module_util_name(name): # type: (str) -> bool
+ """Return True if the given name is a module_util name for the content under test. External module_utils are ignored."""
+ if data_context().content.is_ansible and name.startswith('ansible.module_utils.'):
+ return True
+
+ if data_context().content.collection and name.startswith('ansible_collections.%s.plugins.module_utils.' % data_context().content.collection.full_name):
+ return True
+
+ return False