diff options
author | Matt Clay <mclay@redhat.com> | 2020-09-03 18:42:38 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-03 18:42:38 -0700 |
commit | 909ac41247387231c85c59c42645381c6a31928f (patch) | |
tree | bd68c7859ef96e907720f059001ab514fa2129d3 | |
parent | 04ff489698dd1d2c3e4c720242b30d6da54a4cf1 (diff) | |
download | ansible-909ac41247387231c85c59c42645381c6a31928f.tar.gz |
[stable-2.8] Backport ansible-test CI provider support. (#71625)
* Add types.py from devel to support backport.
* [stable-2.9] Backport ansible-test CI provider support. (#71614)
* Add encoding.py from devel to support backports.
* Add io.py from devel to support backports.
* Update ansible-test support for CI providers. (#69522)
Refactored CI provider code to simplify multiple provider support and addition of new providers.
(cherry picked from commit d8e0aadc0d630c776349b745243c20b62f22ebec)
* Add Shippable request signing to ansible-test. (#69526)
(cherry picked from commit e7c2eb519be2612832de15aa85c5d7618e979f85)
* ansible-test local change detection: use --base-branch if specified (#69508)
(cherry picked from commit 43acd61901e26fdb2faf89baa3e9b2c7647fc89e)
* Add Azure Pipelines support to ansible-test.
(cherry picked from commit 8ffaed00f87f1ae41f86cb4a7fe399d8aa339286)
* Update ansible-test remote endpoint handling. (#71413)
* Request ansible-core-ci resources by provider.
* Remove obsolete us-east-2 CI endpoint.
* Add new --remote-endpoint option.
* Add warning for --remote-aws-region option.
* Update service endpoints.
* Allow non-standard remote stages.
* Add changelog fragment.
(cherry picked from commit d099591964d6fedc174eb1d3fc1bbee8d2ba0f16)
* Fix ansible-test coverage traceback. (#71446)
* Add integration test for ansible-test coverage.
* Fix ansible-test coverage traceback.
* Fix coverage reporting on Python 2.6.
(cherry picked from commit f5b6df14ab2e691f5f059fff0fcf59449132549f)
* Use new endpoint for Parallels based instances.
(cherry picked from commit 98febab9751ed73b70c3004fe9dafb4dc09abf26)
* Add pause to avoid same mtime in test.
(cherry picked from commit 3d769f3a76a867ea61df16eee328bd40fc91a950)
Co-authored-by: Felix Fontein <felix@fontein.de>
(cherry picked from commit 417e408f596cbf48503f2240b8e2b9a97ed05d51)
38 files changed, 1390 insertions, 464 deletions
diff --git a/changelogs/fragments/69508-ansible-test-local-changes-detection.yml b/changelogs/fragments/69508-ansible-test-local-changes-detection.yml new file mode 100644 index 0000000000..1157f945fd --- /dev/null +++ b/changelogs/fragments/69508-ansible-test-local-changes-detection.yml @@ -0,0 +1,2 @@ +bugfixes: +- "ansible-test - for local change detection, allow to specify branch to compare to with ``--base-branch`` for all types of tests (https://github.com/ansible/ansible/pull/69508)." diff --git a/changelogs/fragments/ansible-test-ci-support-azure.yml b/changelogs/fragments/ansible-test-ci-support-azure.yml new file mode 100644 index 0000000000..541cfebdf1 --- /dev/null +++ b/changelogs/fragments/ansible-test-ci-support-azure.yml @@ -0,0 +1,2 @@ +minor_changes: + - ansible-test - Added CI provider support for Azure Pipelines. diff --git a/changelogs/fragments/ansible-test-ci-support-shippable-auth.yml b/changelogs/fragments/ansible-test-ci-support-shippable-auth.yml new file mode 100644 index 0000000000..0ac250799f --- /dev/null +++ b/changelogs/fragments/ansible-test-ci-support-shippable-auth.yml @@ -0,0 +1,2 @@ +minor_changes: + - ansible-test - Added support for Ansible Core CI request signing for Shippable. diff --git a/changelogs/fragments/ansible-test-ci-support.yml b/changelogs/fragments/ansible-test-ci-support.yml new file mode 100644 index 0000000000..8739a977a8 --- /dev/null +++ b/changelogs/fragments/ansible-test-ci-support.yml @@ -0,0 +1,2 @@ +minor_changes: + - ansible-test - Refactored CI related logic into a basic provider abstraction. diff --git a/changelogs/fragments/ansible-test-endpoint-update.yml b/changelogs/fragments/ansible-test-endpoint-update.yml new file mode 100644 index 0000000000..b5634afc9b --- /dev/null +++ b/changelogs/fragments/ansible-test-endpoint-update.yml @@ -0,0 +1,7 @@ +minor_changes: + - ansible-test - Allow custom ``--remote-stage`` options for development and testing. + - ansible-test - Update built-in service endpoints for the ``--remote`` option. + - ansible-test - Show a warning when the obsolete ``--remote-aws-region`` option is used. + - ansible-test - Support custom remote endpoints with the ``--remote-endpoint`` option. + - ansible-test - Remove the discontinued ``us-east-2`` choice from the ``--remote-aws-region`` option. + - ansible-test - Request remote resources by provider name for all provider types. diff --git a/changelogs/fragments/ansible-test-parallels-endpoint.yml b/changelogs/fragments/ansible-test-parallels-endpoint.yml new file mode 100644 index 0000000000..71b74e3bdd --- /dev/null +++ b/changelogs/fragments/ansible-test-parallels-endpoint.yml @@ -0,0 +1,2 @@ +minor_changes: + - ansible-test - Use new endpoint for Parallels based instances with the ``--remote`` option. diff --git a/test/integration/targets/file/tasks/directory_as_dest.yml b/test/integration/targets/file/tasks/directory_as_dest.yml index b51fa72a7f..9b6ddb5dc9 100644 --- a/test/integration/targets/file/tasks/directory_as_dest.yml +++ b/test/integration/targets/file/tasks/directory_as_dest.yml @@ -242,6 +242,10 @@ follow: False register: file8_initial_dir_stat +- name: Pause to ensure stat times are not the exact same + pause: + seconds: 1 + - name: Use touch with directory as dest file: dest: '{{output_dir}}/sub1' diff --git a/test/runner/lib/changes.py b/test/runner/lib/changes.py deleted file mode 100644 index ca31219b9c..0000000000 --- a/test/runner/lib/changes.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Detect changes in Ansible code.""" - -from __future__ import absolute_import, print_function - -import re -import os - -from lib.util import ( - ApplicationError, - SubprocessError, - MissingEnvironmentVariable, - CommonConfig, - display, -) - -from lib.http import ( - HttpClient, - urlencode, -) - -from lib.git import ( - Git, -) - - -class InvalidBranch(ApplicationError): - """Exception for invalid branch specification.""" - def __init__(self, branch, reason): - """ - :type branch: str - :type reason: str - """ - message = 'Invalid branch: %s\n%s' % (branch, reason) - - super(InvalidBranch, self).__init__(message) - - self.branch = branch - - -class ChangeDetectionNotSupported(ApplicationError): - """Exception for cases where change detection is not supported.""" - pass - - -class ShippableChanges(object): - """Change information for Shippable build.""" - def __init__(self, args, git): - """ - :type args: CommonConfig - :type git: Git - """ - self.args = args - - try: - self.branch = os.environ['BRANCH'] - self.is_pr = os.environ['IS_PULL_REQUEST'] == 'true' - self.is_tag = os.environ['IS_GIT_TAG'] == 'true' - self.commit = os.environ['COMMIT'] - self.project_id = os.environ['PROJECT_ID'] - self.commit_range = os.environ['SHIPPABLE_COMMIT_RANGE'] - except KeyError as ex: - raise MissingEnvironmentVariable(name=ex.args[0]) - - if self.is_tag: - raise ChangeDetectionNotSupported('Change detection is not supported for tags.') - - if self.is_pr: - self.paths = sorted(git.get_diff_names([self.commit_range])) - self.diff = git.get_diff([self.commit_range]) - else: - merge_runs = self.get_merge_runs(self.project_id, self.branch) - last_successful_commit = self.get_last_successful_commit(git, merge_runs) - - if last_successful_commit: - self.paths = sorted(git.get_diff_names([last_successful_commit, self.commit])) - self.diff = git.get_diff([last_successful_commit, self.commit]) - else: - # first run for branch - self.paths = None # act as though change detection not enabled, do not filter targets - self.diff = [] - - def get_merge_runs(self, project_id, branch): - """ - :type project_id: str - :type branch: str - :rtype: list[dict] - """ - params = dict( - isPullRequest='false', - projectIds=project_id, - branch=branch, - ) - - client = HttpClient(self.args, always=True) - response = client.get('https://api.shippable.com/runs?%s' % urlencode(params)) - return response.json() - - @staticmethod - def get_last_successful_commit(git, merge_runs): - """ - :type git: Git - :type merge_runs: dict | list[dict] - :rtype: str - """ - if 'id' in merge_runs and merge_runs['id'] == 4004: - display.warning('Unable to find project. Cannot determine changes. All tests will be executed.') - return None - - successful_commits = set(run['commitSha'] for run in merge_runs if run['statusCode'] == 30) - commit_history = git.get_rev_list(max_count=100) - ordered_successful_commits = [commit for commit in commit_history if commit in successful_commits] - last_successful_commit = ordered_successful_commits[0] if ordered_successful_commits else None - - if last_successful_commit is None: - display.warning('No successful commit found. All tests will be executed.') - - return last_successful_commit - - -class LocalChanges(object): - """Change information for local work.""" - def __init__(self, args, git): - """ - :type args: CommonConfig - :type git: Git - """ - self.args = args - self.current_branch = git.get_branch() - - if self.is_official_branch(self.current_branch): - raise InvalidBranch(branch=self.current_branch, - reason='Current branch is not a feature branch.') - - self.fork_branch = None - self.fork_point = None - - self.local_branches = sorted(git.get_branches()) - self.official_branches = sorted([b for b in self.local_branches if self.is_official_branch(b)]) - - for self.fork_branch in self.official_branches: - try: - self.fork_point = git.get_branch_fork_point(self.fork_branch) - break - except SubprocessError: - pass - - if self.fork_point is None: - raise ApplicationError('Unable to auto-detect fork branch and fork point.') - - # tracked files (including unchanged) - self.tracked = sorted(git.get_file_names(['--cached'])) - # untracked files (except ignored) - self.untracked = sorted(git.get_file_names(['--others', '--exclude-standard'])) - # tracked changes (including deletions) committed since the branch was forked - self.committed = sorted(git.get_diff_names([self.fork_point, 'HEAD'])) - # tracked changes (including deletions) which are staged - self.staged = sorted(git.get_diff_names(['--cached'])) - # tracked changes (including deletions) which are not staged - self.unstaged = sorted(git.get_diff_names([])) - # diff of all tracked files from fork point to working copy - self.diff = git.get_diff([self.fork_point]) - - @staticmethod - def is_official_branch(name): - """ - :type name: str - :rtype: bool - """ - if name == 'devel': - return True - - if re.match(r'^stable-[0-9]+\.[0-9]+$', name): - return True - - return False diff --git a/test/runner/lib/ci/__init__.py b/test/runner/lib/ci/__init__.py new file mode 100644 index 0000000000..d6e2ad6e75 --- /dev/null +++ b/test/runner/lib/ci/__init__.py @@ -0,0 +1,227 @@ +"""Support code for CI environments.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import abc +import base64 +import json +import os +import tempfile + + +from .. import types as t + +from ..encoding import ( + to_bytes, + to_text, +) + +from ..io import ( + read_text_file, + write_text_file, +) + +from ..config import ( + CommonConfig, + TestConfig, +) + +from ..util import ( + ABC, + ApplicationError, + display, + get_subclasses, + import_plugins, + raw_command, +) + + +class ChangeDetectionNotSupported(ApplicationError): + """Exception for cases where change detection is not supported.""" + + +class AuthContext: + """Context information required for Ansible Core CI authentication.""" + def __init__(self): # type: () -> None + self.region = None # type: t.Optional[str] + + +class CIProvider(ABC): + """Base class for CI provider plugins.""" + priority = 500 + + @staticmethod + @abc.abstractmethod + def is_supported(): # type: () -> bool + """Return True if this provider is supported in the current running environment.""" + + @property + @abc.abstractmethod + def code(self): # type: () -> str + """Return a unique code representing this provider.""" + + @property + @abc.abstractmethod + def name(self): # type: () -> str + """Return descriptive name for this provider.""" + + @abc.abstractmethod + def generate_resource_prefix(self): # type: () -> str + """Return a resource prefix specific to this CI provider.""" + + @abc.abstractmethod + def get_base_branch(self): # type: () -> str + """Return the base branch or an empty string.""" + + @abc.abstractmethod + def detect_changes(self, args): # type: (TestConfig) -> t.Optional[t.List[str]] + """Initialize change detection.""" + + @abc.abstractmethod + def supports_core_ci_auth(self, context): # type: (AuthContext) -> bool + """Return True if Ansible Core CI is supported.""" + + @abc.abstractmethod + def prepare_core_ci_auth(self, context): # type: (AuthContext) -> t.Dict[str, t.Any] + """Return authentication details for Ansible Core CI.""" + + @abc.abstractmethod + def get_git_details(self, args): # type: (CommonConfig) -> t.Optional[t.Dict[str, t.Any]] + """Return details about git in the current environment.""" + + +def get_ci_provider(): # type: () -> CIProvider + """Return a CI provider instance for the current environment.""" + try: + return get_ci_provider.provider + except AttributeError: + pass + + provider = None + + import_plugins('ci') + + candidates = sorted(get_subclasses(CIProvider), key=lambda c: (c.priority, c.__name__)) + + for candidate in candidates: + if candidate.is_supported(): + provider = candidate() + break + + if provider.code: + display.info('Detected CI provider: %s' % provider.name) + + get_ci_provider.provider = provider + + return provider + + +class AuthHelper(ABC): + """Public key based authentication helper for Ansible Core CI.""" + def sign_request(self, request): # type: (t.Dict[str, t.Any]) -> None + """Sign the given auth request and make the public key available.""" + payload_bytes = to_bytes(json.dumps(request, sort_keys=True)) + signature_raw_bytes = self.sign_bytes(payload_bytes) + signature = to_text(base64.b64encode(signature_raw_bytes)) + + request.update(signature=signature) + + def initialize_private_key(self): # type: () -> str + """ + Initialize and publish a new key pair (if needed) and return the private key. + The private key is cached across ansible-test invocations so it is only generated and published once per CI job. + """ + path = os.path.expanduser('~/.ansible-core-ci-private.key') + + if os.path.exists(to_bytes(path)): + private_key_pem = read_text_file(path) + else: + private_key_pem = self.generate_private_key() + write_text_file(path, private_key_pem) + + return private_key_pem + + @abc.abstractmethod + def sign_bytes(self, payload_bytes): # type: (bytes) -> bytes + """Sign the given payload and return the signature, initializing a new key pair if required.""" + + @abc.abstractmethod + def publish_public_key(self, public_key_pem): # type: (str) -> None + """Publish the given public key.""" + + @abc.abstractmethod + def generate_private_key(self): # type: () -> str + """Generate a new key pair, publishing the public key and returning the private key.""" + + +class CryptographyAuthHelper(AuthHelper, ABC): # pylint: disable=abstract-method + """Cryptography based public key based authentication helper for Ansible Core CI.""" + def sign_bytes(self, payload_bytes): # type: (bytes) -> bytes + """Sign the given payload and return the signature, initializing a new key pair if required.""" + # import cryptography here to avoid overhead and failures in environments which do not use/provide it + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.serialization import load_pem_private_key + + private_key_pem = self.initialize_private_key() + private_key = load_pem_private_key(to_bytes(private_key_pem), None, default_backend()) + + signature_raw_bytes = private_key.sign(payload_bytes, ec.ECDSA(hashes.SHA256())) + + return signature_raw_bytes + + def generate_private_key(self): # type: () -> str + """Generate a new key pair, publishing the public key and returning the private key.""" + # import cryptography here to avoid overhead and failures in environments which do not use/provide it + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ec + + private_key = ec.generate_private_key(ec.SECP384R1(), default_backend()) + public_key = private_key.public_key() + + private_key_pem = to_text(private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + )) + + public_key_pem = to_text(public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + )) + + self.publish_public_key(public_key_pem) + + return private_key_pem + + +class OpenSSLAuthHelper(AuthHelper, ABC): # pylint: disable=abstract-method + """OpenSSL based public key based authentication helper for Ansible Core CI.""" + def sign_bytes(self, payload_bytes): # type: (bytes) -> bytes + """Sign the given payload and return the signature, initializing a new key pair if required.""" + private_key_pem = self.initialize_private_key() + + with tempfile.NamedTemporaryFile() as private_key_file: + private_key_file.write(to_bytes(private_key_pem)) + private_key_file.flush() + + with tempfile.NamedTemporaryFile() as payload_file: + payload_file.write(payload_bytes) + payload_file.flush() + + with tempfile.NamedTemporaryFile() as signature_file: + raw_command(['openssl', 'dgst', '-sha256', '-sign', private_key_file.name, '-out', signature_file.name, payload_file.name], capture=True) + signature_raw_bytes = signature_file.read() + + return signature_raw_bytes + + def generate_private_key(self): # type: () -> str + """Generate a new key pair, publishing the public key and returning the private key.""" + private_key_pem = raw_command(['openssl', 'ecparam', '-genkey', '-name', 'secp384r1', '-noout'], capture=True)[0] + public_key_pem = raw_command(['openssl', 'ec', '-pubout'], data=private_key_pem, capture=True)[0] + + self.publish_public_key(public_key_pem) + + return private_key_pem diff --git a/test/runner/lib/ci/azp.py b/test/runner/lib/ci/azp.py new file mode 100644 index 0000000000..5b816b547c --- /dev/null +++ b/test/runner/lib/ci/azp.py @@ -0,0 +1,262 @@ +"""Support code for working with Azure Pipelines.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re +import tempfile +import uuid + +from .. import types as t + +from ..encoding import ( + to_bytes, +) + +from ..config import ( + CommonConfig, + TestConfig, +) + +from ..git import ( + Git, +) + +from ..http import ( + HttpClient, + urlencode, +) + +from ..util import ( + display, + MissingEnvironmentVariable, +) + +from . import ( + AuthContext, + ChangeDetectionNotSupported, + CIProvider, + CryptographyAuthHelper, +) + +CODE = 'azp' + + +class AzurePipelines(CIProvider): + """CI provider implementation for Azure Pipelines.""" + def __init__(self): + self.auth = AzurePipelinesAuthHelper() + + @staticmethod + def is_supported(): # type: () -> bool + """Return True if this provider is supported in the current running environment.""" + return os.environ.get('SYSTEM_COLLECTIONURI', '').startswith('https://dev.azure.com/') + + @property + def code(self): # type: () -> str + """Return a unique code representing this provider.""" + return CODE + + @property + def name(self): # type: () -> str + """Return descriptive name for this provider.""" + return 'Azure Pipelines' + + def generate_resource_prefix(self): # type: () -> str + """Return a resource prefix specific to this CI provider.""" + try: + prefix = 'azp-%s-%s-%s' % ( + os.environ['BUILD_BUILDID'], + os.environ['SYSTEM_JOBATTEMPT'], + os.environ['SYSTEM_JOBIDENTIFIER'], + ) + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + prefix = re.sub(r'[^a-zA-Z0-9]+', '-', prefix) + + return prefix + + def get_base_branch(self): # type: () -> str + """Return the base branch or an empty string.""" + base_branch = os.environ.get('SYSTEM_PULLREQUEST_TARGETBRANCH') or os.environ.get('BUILD_SOURCEBRANCHNAME') + + if base_branch: + base_branch = 'origin/%s' % base_branch + + return base_branch or '' + + def detect_changes(self, args): # type: (TestConfig) -> t.Optional[t.List[str]] + """Initialize change detection.""" + result = AzurePipelinesChanges(args) + + if result.is_pr: + job_type = 'pull request' + else: + job_type = 'merge commit' + + display.info('Processing %s for branch %s commit %s' % (job_type, result.branch, result.commit)) + + if not args.metadata.changes: + args.metadata.populate_changes(result.diff) + + if result.paths is None: + # There are several likely causes of this: + # - First run on a new branch. + # - Too many pull requests passed since the last merge run passed. + display.warning('No successful commit found. All tests will be executed.') + + return result.paths + + def supports_core_ci_auth(self, context): # type: (AuthContext) -> bool + """Return True if Ansible Core CI is supported.""" + return True + + def prepare_core_ci_auth(self, context): # type: (AuthContext) -> t.Dict[str, t.Any] + """Return authentication details for Ansible Core CI.""" + try: + request = dict( + org_name=os.environ['SYSTEM_COLLECTIONURI'].strip('/').split('/')[-1], + project_name=os.environ['SYSTEM_TEAMPROJECT'], + build_id=int(os.environ['BUILD_BUILDID']), + task_id=str(uuid.UUID(os.environ['SYSTEM_TASKINSTANCEID'])), + ) + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + self.auth.sign_request(request) + + auth = dict( + azp=request, + ) + + return auth + + def get_git_details(self, args): # type: (CommonConfig) -> t.Optional[t.Dict[str, t.Any]] + """Return details about git in the current environment.""" + changes = AzurePipelinesChanges(args) + + details = dict( + base_commit=changes.base_commit, + commit=changes.commit, + ) + + return details + + +class AzurePipelinesAuthHelper(CryptographyAuthHelper): + """ + Authentication helper for Azure Pipelines. + Based on cryptography since it is provided by the default Azure Pipelines environment. + """ + def publish_public_key(self, public_key_pem): # type: (str) -> None + """Publish the given public key.""" + # the temporary file cannot be deleted because we do not know when the agent has processed it + with tempfile.NamedTemporaryFile(prefix='public-key-', suffix='.pem', delete=False) as public_key_file: + public_key_file.write(to_bytes(public_key_pem)) + public_key_file.flush() + + # make the agent aware of the public key by declaring it as an attachment + vso_add_attachment('ansible-core-ci', 'public-key.pem', public_key_file.name) + + +class AzurePipelinesChanges: + """Change information for an Azure Pipelines build.""" + def __init__(self, args): # type: (CommonConfig) -> None + self.args = args + self.git = Git(args) + + try: + self.org_uri = os.environ['SYSTEM_COLLECTIONURI'] # ex: https://dev.azure.com/{org}/ + self.project = os.environ['SYSTEM_TEAMPROJECT'] + self.repo_type = os.environ['BUILD_REPOSITORY_PROVIDER'] # ex: GitHub + self.source_branch = os.environ['BUILD_SOURCEBRANCH'] + self.source_branch_name = os.environ['BUILD_SOURCEBRANCHNAME'] + self.pr_branch_name = os.environ.get('SYSTEM_PULLREQUEST_TARGETBRANCH') + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + if self.source_branch.startswith('refs/tags/'): + raise ChangeDetectionNotSupported('Change detection is not supported for tags.') + + self.org = self.org_uri.strip('/').split('/')[-1] + self.is_pr = self.pr_branch_name is not None + + if self.is_pr: + # HEAD is a merge commit of the PR branch into the target branch + # HEAD^1 is HEAD of the target branch (first parent of merge commit) + # HEAD^2 is HEAD of the PR branch (second parent of merge commit) + # see: https://git-scm.com/docs/gitrevisions + self.branch = self.pr_branch_name + self.base_commit = 'HEAD^1' + self.commit = 'HEAD^2' + else: + commits = self.get_successful_merge_run_commits() + + self.branch = self.source_branch_name + self.base_commit = self.get_last_successful_commit(commits) + self.commit = 'HEAD' + + self.commit = self.git.run_git(['rev-parse', self.commit]).strip() + + if self.base_commit: + self.base_commit = self.git.run_git(['rev-parse', self.base_commit]).strip() + + # <rev1>...<rev2> + # Include commits that are reachable from <rev2> but exclude those that are reachable from <rev1>. + # see: https://git-scm.com/docs/gitrevisions + dot_range = '%s..%s' % (self.base_commit, self.commit) + + self.paths = sorted(self.git.get_diff_names([dot_range])) + self.diff = self.git.get_diff([dot_range]) + else: + self.paths = None # act as though change detection not enabled, do not filter targets + self.diff = [] + + def get_successful_merge_run_commits(self): # type: () -> t.Set[str] + """Return a set of recent successsful merge commits from Azure Pipelines.""" + parameters = dict( + maxBuildsPerDefinition=100, # max 5000 + queryOrder='queueTimeDescending', # assumes under normal circumstances that later queued jobs are for later commits + resultFilter='succeeded', + reasonFilter='batchedCI', # may miss some non-PR reasons, the alternative is to filter the list after receiving it + repositoryType=self.repo_type, + repositoryId='%s/%s' % (self.org, self.project), + ) + + url = '%s%s/build/builds?%s' % (self.org_uri, self.project, urlencode(parameters)) + + http = HttpClient(self.args) + response = http.get(url) + + # noinspection PyBroadException + try: + result = response.json() + except Exception: # pylint: disable=broad-except + # most likely due to a private project, which returns an HTTP 203 response with HTML + display.warning('Unable to find project. Cannot determine changes. All tests will be executed.') + return set() + + commits = set(build['sourceVersion'] for build in result['value']) + + return commits + + def get_last_successful_commit(self, commits): # type: (t.Set[str]) -> t.Optional[str] + """Return the last successful commit from git history that is found in the given commit list, or None.""" + commit_history = self.git.get_rev_list(max_count=100) + ordered_successful_commits = [commit for commit in commit_history if commit in commits] + last_successful_commit = ordered_successful_commits[0] if ordered_successful_commits else None + return last_successful_commit + + +def vso_add_attachment(file_type, file_name, path): # type: (str, str, str) -> None + """Upload and attach a file to the current timeline record.""" + vso('task.addattachment', dict(type=file_type, name=file_name), path) + + +def vso(name, data, message): # type: (str, t.Dict[str, str], str) -> None + """ + Write a logging command for the Azure Pipelines agent to process. + See: https://docs.microsoft.com/en-us/azure/devops/pipelines/scripts/logging-commands?view=azure-devops&tabs=bash + """ + display.info('##vso[%s %s]%s' % (name, ';'.join('='.join((key, value)) for key, value in data.items()), message)) diff --git a/test/runner/lib/ci/local.py b/test/runner/lib/ci/local.py new file mode 100644 index 0000000000..be87129a8c --- /dev/null +++ b/test/runner/lib/ci/local.py @@ -0,0 +1,217 @@ +"""Support code for working without a supported CI provider.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import platform +import random +import re + +from .. import types as t + +from ..config import ( + CommonConfig, + TestConfig, +) + +from ..io import ( + read_text_file, +) + +from ..git import ( + Git, +) + +from ..util import ( + ApplicationError, + display, + is_binary_file, + SubprocessError, +) + +from . import ( + AuthContext, + CIProvider, +) + +CODE = '' # not really a CI provider, so use an empty string for the code + + +class Local(CIProvider): + """CI provider implementation when not using CI.""" + priority = 1000 + + @staticmethod + def is_supported(): # type: () -> bool + """Return True if this provider is supported in the current running environment.""" + return True + + @property + def code(self): # type: () -> str + """Return a unique code representing this provider.""" + return CODE + + @property + def name(self): # type: () -> str + """Return descriptive name for this provider.""" + return 'Local' + + def generate_resource_prefix(self): # type: () -> str + """Return a resource prefix specific to this CI provider.""" + node = re.sub(r'[^a-zA-Z0-9]+', '-', platform.node().split('.')[0]).lower() + + prefix = 'ansible-test-%s-%d' % (node, random.randint(10000000, 99999999)) + + return prefix + + def get_base_branch(self): # type: () -> str + """Return the base branch or an empty string.""" + return '' + + def detect_changes(self, args): # type: (TestConfig) -> t.Optional[t.List[str]] + """Initialize change detection.""" + result = LocalChanges(args) + + display.info('Detected branch %s forked from %s at commit %s' % ( + result.current_branch, result.fork_branch, result.fork_point)) + + if result.untracked and not args.untracked: + display.warning('Ignored %s untracked file(s). Use --untracked to include them.' % + len(result.untracked)) + + if result.committed and not args.committed: + display.warning('Ignored %s committed change(s). Omit --ignore-committed to include them.' % + len(result.committed)) + + if result.staged and not args.staged: + display.warning('Ignored %s staged change(s). Omit --ignore-staged to include them.' % + len(result.staged)) + + if result.unstaged and not args.unstaged: + display.warning('Ignored %s unstaged change(s). Omit --ignore-unstaged to include them.' % + len(result.unstaged)) + + names = set() + + if args.tracked: + names |= set(result.tracked) + if args.untracked: + names |= set(result.untracked) + if args.committed: + names |= set(result.committed) + if args.staged: + names |= set(result.staged) + if args.unstaged: + names |= set(result.unstaged) + + if not args.metadata.changes: + args.metadata.populate_changes(result.diff) + + for path in result.untracked: + if is_binary_file(path): + args.metadata.changes[path] = ((0, 0),) + continue + + line_count = len(read_text_file(path).splitlines()) + + args.metadata.changes[path] = ((1, line_count),) + + return sorted(names) + + def supports_core_ci_auth(self, context): # type: (AuthContext) -> bool + """Return True if Ansible Core CI is supported.""" + path = self._get_aci_key_path(context) + return os.path.exists(path) + + def prepare_core_ci_auth(self, context): # type: (AuthContext) -> t.Dict[str, t.Any] + """Return authentication details for Ansible Core CI.""" + path = self._get_aci_key_path(context) + auth_key = read_text_file(path).strip() + + request = dict( + key=auth_key, + nonce=None, + ) + + auth = dict( + remote=request, + ) + + return auth + + def get_git_details(self, args): # type: (CommonConfig) -> t.Optional[t.Dict[str, t.Any]] + """Return details about git in the current environment.""" + return None # not yet implemented for local + + def _get_aci_key_path(self, context): # type: (AuthContext) -> str + path = os.path.expanduser('~/.ansible-core-ci.key') + + if context.region: + path += '.%s' % context.region + + return path + + +class InvalidBranch(ApplicationError): + """Exception for invalid branch specification.""" + def __init__(self, branch, reason): # type: (str, str) -> None + message = 'Invalid branch: %s\n%s' % (branch, reason) + + super(InvalidBranch, self).__init__(message) + + self.branch = branch + + +class LocalChanges: + """Change information for local work.""" + def __init__(self, args): # type: (CommonConfig) -> None + self.args = args + self.git = Git(args) + + self.current_branch = self.git.get_branch() + + if self.is_official_branch(self.current_branch): + raise InvalidBranch(branch=self.current_branch, + reason='Current branch is not a feature branch.') + + self.fork_branch = None + self.fork_point = None + + self.local_branches = sorted(self.git.get_branches()) + self.official_branches = sorted([b for b in self.local_branches if self.is_official_branch(b)]) + + for self.fork_branch in self.official_branches: + try: + self.fork_point = self.git.get_branch_fork_point(self.fork_branch) + break + except SubprocessError: + pass + + if self.fork_point is None: + raise ApplicationError('Unable to auto-detect fork branch and fork point.') + + # tracked files (including unchanged) + self.tracked = sorted(self.git.get_file_names(['--cached'])) + # untracked files (except ignored) + self.untracked = sorted(self.git.get_file_names(['--others', '--exclude-standard'])) + # tracked changes (including deletions) committed since the branch was forked + self.committed = sorted(self.git.get_diff_names([self.fork_point, 'HEAD'])) + # tracked changes (including deletions) which are staged + self.staged = sorted(self.git.get_diff_names(['--cached'])) + # tracked changes (including deletions) which are not staged + self.unstaged = sorted(self.git.get_diff_names([])) + # diff of all tracked files from fork point to working copy + self.diff = self.git.get_diff([self.fork_point]) + + def is_official_branch(self, name): # type: (str) -> bool + """Return True if the given branch name an official branch for development or releases.""" + if self.args.base_branch: + return name == self.args.base_branch + + if name == 'devel': + return True + + if re.match(r'^stable-[0-9]+\.[0-9]+$', name): + return True + + return False diff --git a/test/runner/lib/ci/shippable.py b/test/runner/lib/ci/shippable.py new file mode 100644 index 0000000000..bdb2f5964c --- /dev/null +++ b/test/runner/lib/ci/shippable.py @@ -0,0 +1,269 @@ +"""Support code for working with Shippable.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re +import time + +from .. import types as t + +from ..config import ( + CommonConfig, + TestConfig, +) + +from ..git import ( + Git, +) + +from ..http import ( + HttpClient, + urlencode, +) + +from ..util import ( + ApplicationError, + display, + MissingEnvironmentVariable, + SubprocessError, +) + +from . import ( + AuthContext, + ChangeDetectionNotSupported, + CIProvider, + OpenSSLAuthHelper, +) + + +CODE = 'shippable' + + +class Shippable(CIProvider): + """CI provider implementation for Shippable.""" + def __init__(self): + self.auth = ShippableAuthHelper() + + @staticmethod + def is_supported(): # type: () -> bool + """Return True if this provider is supported in the current running environment.""" + return os.environ.get('SHIPPABLE') == 'true' + + @property + def code(self): # type: () -> str + """Return a unique code representing this provider.""" + return CODE + + @property + def name(self): # type: () -> str + """Return descriptive name for this provider.""" + return 'Shippable' + + def generate_resource_prefix(self): # type: () -> str + """Return a resource prefix specific to this CI provider.""" + try: + prefix = 'shippable-%s-%s' % ( + os.environ['SHIPPABLE_BUILD_NUMBER'], + os.environ['SHIPPABLE_JOB_NUMBER'], + ) + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + return prefix + + def get_base_branch(self): # type: () -> str + """Return the base branch or an empty string.""" + base_branch = os.environ.get('BASE_BRANCH') + + if base_branch: + base_branch = 'origin/%s' % base_branch + + return base_branch or '' + + def detect_changes(self, args): # type: (TestConfig) -> t.Optional[t.List[str]] + """Initialize change detection.""" + result = ShippableChanges(args) + + if result.is_pr: + job_type = 'pull request' + elif result.is_tag: + job_type = 'tag' + else: + job_type = 'merge commit' + + display.info('Processing %s for branch %s commit %s' % (job_type, result.branch, result.commit)) + + if not args.metadata.changes: + args.metadata.populate_changes(result.diff) + + if result.paths is None: + # There are several likely causes of this: + # - First run on a new branch. + # - Too many pull requests passed since the last merge run passed. + display.warning('No successful commit found. All tests will be executed.') + + return result.paths + + def supports_core_ci_auth(self, context): # type: (AuthContext) -> bool + """Return True if Ansible Core CI is supported.""" + return True + + def prepare_core_ci_auth(self, context): # type: (AuthContext) -> t.Dict[str, t.Any] + """Return authentication details for Ansible Core CI.""" + try: + request = dict( + run_id=os.environ['SHIPPABLE_BUILD_ID'], + job_number=int(os.environ['SHIPPABLE_JOB_NUMBER']), + ) + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + self.auth.sign_request(request) + + auth = dict( + shippable=request, + ) + + return auth + + def get_git_details(self, args): # type: (CommonConfig) -> t.Optional[t.Dict[str, t.Any]] + """Return details about git in the current environment.""" + commit = os.environ.get('COMMIT') + base_commit = os.environ.get('BASE_COMMIT') + + details = dict( + base_commit=base_commit, + commit=commit, + merged_commit=self._get_merged_commit(args, commit), + ) + + return details + + # noinspection PyUnusedLocal + def _get_merged_commit(self, args, commit): # type: (CommonConfig, str) -> t.Optional[str] # pylint: disable=unused-argument + """Find the merged commit that should be present.""" + if not commit: + return None + + git = Git(args) + + try: + show_commit = git.run_git(['show', '--no-patch', '--no-abbrev', commit]) + except SubprocessError as ex: + # This should only fail for pull requests where the commit does not exist. + # Merge runs would fail much earlier when attempting to checkout the commit. + raise ApplicationError('Commit %s was not found:\n\n%s\n\n' + 'GitHub may not have fully replicated the commit across their infrastructure.\n' + 'It is also possible the commit was removed by a force push between job creation and execution.\n' + 'Find the latest run for the pull request and restart failed jobs as needed.' + % (commit, ex.stderr.strip())) + + head_commit = git.run_git(['show', '--no-patch', '--no-abbrev', 'HEAD']) + + if show_commit == head_commit: + # Commit is HEAD, so this is not a pull request or the base branch for the pull request is up-to-date. + return None + + match_merge = re.search(r'^Merge: (?P<parents>[0-9a-f]{40} [0-9a-f]{40})$', head_commit, flags=re.MULTILINE) + + if not match_merge: + # The most likely scenarios resulting in a failure here are: + # A new run should or does supersede this job, but it wasn't cancelled in time. + # A job was superseded and then later restarted. + raise ApplicationError('HEAD is not commit %s or a merge commit:\n\n%s\n\n' + 'This job has likely been superseded by another run due to additional commits being pushed.\n' + 'Find the latest run for the pull request and restart failed jobs as needed.' + % (commit, head_commit.strip())) + + parents = set(match_merge.group('parents').split(' ')) + + if len(parents) != 2: + raise ApplicationError('HEAD is a %d-way octopus merge.' % len(parents)) + + if commit not in parents: + raise ApplicationError('Commit %s is not a parent of HEAD.' % commit) + + parents.remove(commit) + + last_commit = parents.pop() + + return last_commit + + +class ShippableAuthHelper(OpenSSLAuthHelper): + """ + Authentication helper for Shippable. + Based on OpenSSL since cryptography is not provided by the default Shippable environment. + """ + def publish_public_key(self, public_key_pem): # type: (str) -> None + """Publish the given public key.""" + # display the public key as a single line to avoid mangling such as when prefixing each line with a timestamp + display.info(public_key_pem.replace('\n', ' ')) + # allow time for logs to become available to reduce repeated API calls + time.sleep(3) + + +class ShippableChanges: + """Change information for Shippable build.""" + def __init__(self, args): # type: (CommonConfig) -> None + self.args = args + self.git = Git(args) + + try: + self.branch = os.environ['BRANCH'] + self.is_pr = os.environ['IS_PULL_REQUEST'] == 'true' + self.is_tag = os.environ['IS_GIT_TAG'] == 'true' + self.commit = os.environ['COMMIT'] + self.project_id = os.environ['PROJECT_ID'] + self.commit_range = os.environ['SHIPPABLE_COMMIT_RANGE'] + except KeyError as ex: + raise MissingEnvironmentVariable(name=ex.args[0]) + + if self.is_tag: + raise ChangeDetectionNotSupported('Change detection is not supported for tags.') + + if self.is_pr: + self.paths = sorted(self.git.get_diff_names([self.commit_range])) + self.diff = self.git.get_diff([self.commit_range]) + else: + commits = self.get_successful_merge_run_commits(self.project_id, self.branch) + last_successful_commit = self.get_last_successful_commit(commits) + + if last_successful_commit: + self.paths = sorted(self.git.get_diff_names([last_successful_commit, self.commit])) + self.diff = self.git.get_diff([last_successful_commit, self.commit]) + else: + # first run for branch + self.paths = None # act as though change detection not enabled, do not filter targets + self.diff = [] + + def get_successful_merge_run_commits(self, project_id, branch): # type: (str, str) -> t.Set[str] + """Return a set of recent successsful merge commits from Shippable for the given project and branch.""" + parameters = dict( + isPullRequest='false', + projectIds=project_id, + branch=branch, + ) + + url = 'https://api.shippable.com/runs?%s' % urlencode(parameters) + + http = HttpClient(self.args, always=True) + response = http.get(url) + result = response.json() + + if 'id' in result and result['id'] == 4004: + # most likely due to a private project, which returns an HTTP 200 response with JSON + display.warning('Unable to find project. Cannot determine changes. All tests will be executed.') + return set() + + commits = set(run['commitSha'] for run in result if run['statusCode'] == 30) + + return commits + + def get_last_successful_commit(self, successful_commits): # type: (t.Set[str]) -> t.Optional[str] + """Return the last successful commit from git history that is found in the given commit list, or None.""" + commit_history = self.git.get_rev_list(max_count=100) + ordered_successful_commits = [commit for commit in commit_history if commit in successful_commits] + last_successful_commit = ordered_successful_commits[0] if ordered_successful_commits else None + return last_successful_commit diff --git a/test/runner/lib/cli.py b/test/runner/lib/cli.py index 9987f5dbb5..cb49cacba4 100644 --- a/test/runner/lib/cli.py +++ b/test/runner/lib/cli.py @@ -226,6 +226,9 @@ def parse_args(): test.add_argument('--metadata', help=argparse.SUPPRESS) + test.add_argument('--base-branch', + help='base branch used for change detection') + add_changes(test, argparse) add_environments(test) @@ -423,9 +426,6 @@ def parse_args(): choices=SUPPORTED_PYTHON_VERSIONS + ('default',), help='python version: %s' % ', '.join(SUPPORTED_PYTHON_VERSIONS)) - sanity.add_argument('--base-branch', - help=argparse.SUPPRESS) - add_lint(sanity) add_extra_docker_options(sanity, integration=False) @@ -637,6 +637,7 @@ def add_environments(parser, tox_version=False, tox_only=False): remote_provider=None, remote_aws_region=None, remote_terminate=None, + remote_endpoint=None, python_interpreter=None, ) @@ -658,9 +659,8 @@ def add_environments(parser, tox_version=False, tox_only=False): remote.add_argument('--remote-stage', metavar='STAGE', - help='remote stage to use: %(choices)s', - choices=['prod', 'dev'], - default='prod') + help='remote stage to use: prod, dev', + default='prod').completer = complete_remote_stage remote.add_argument('--remote-provider', metavar='PROVIDER', @@ -668,6 +668,11 @@ def add_environments(parser, tox_version=False, tox_only=False): choices=['default', 'aws', 'azure', 'parallels'], default='default') + remote.add_argument('--remote-endpoint', + metavar='ENDPOINT', + help='remote provisioning endpoint to use (default: auto)', + default=None) + remote.add_argument('--remote-aws-region', metavar='REGION', help='remote aws region to use: %(choices)s (default: auto)', @@ -756,6 +761,16 @@ def add_extra_docker_options(parser, integration=True): help='memory limit for docker in bytes', type=int) +# noinspection PyUnusedLocal +def complete_remote_stage(prefix, parsed_args, **_): # pylint: disable=unused-argument + """ + :type prefix: unicode + :type parsed_args: any + :rtype: list[str] + """ + return [stage for stage in ('prod', 'dev') if stage.startswith(prefix)] + + def complete_target(prefix, parsed_args, **_): """ :type prefix: unicode diff --git a/test/runner/lib/cloud/__init__.py b/test/runner/lib/cloud/__init__.py index 8f21c715ce..43940f8297 100644 --- a/test/runner/lib/cloud/__init__.py +++ b/test/runner/lib/cloud/__init__.py @@ -7,15 +7,12 @@ import datetime import json import time import os -import platform -import random import re import tempfile from lib.util import ( ApplicationError, display, - is_shippable, import_plugins, load_plugins, ABC, @@ -29,6 +26,10 @@ from lib.config import ( IntegrationConfig, ) +from lib.ci import ( + get_ci_provider, +) + PROVIDERS = {} ENVIRONMENTS = {} @@ -260,6 +261,7 @@ class CloudProvider(CloudBase): """ super(CloudProvider, self).__init__(args) + self.ci_provider = get_ci_provider() self.remove_config = False self.config_static_path = '%s/cloud-config-%s%s' % (self.TEST_DIR, self.platform, config_extension) self.config_template_path = '%s.template' % self.config_static_path @@ -280,7 +282,7 @@ class CloudProvider(CloudBase): def setup(self): """Setup the cloud resource before delegation and register a cleanup callback.""" - self.resource_prefix = self._generate_resource_prefix() + self.resource_prefix = self.ci_provider.generate_resource_prefix() atexit.register(self.cleanup) @@ -359,21 +361,6 @@ class CloudProvider(CloudBase): return template - @staticmethod - def _generate_resource_prefix(): - """ - :rtype: str - """ - if is_shippable(): - return 'shippable-%s-%s' % ( - os.environ['SHIPPABLE_BUILD_NUMBER'], - os.environ['SHIPPABLE_JOB_NUMBER'], - ) - - node = re.sub(r'[^a-zA-Z0-9]+', '-', platform.node().split('.')[0]).lower() - - return 'ansible-test-%s-%d' % (node, random.randint(10000000, 99999999)) - class CloudEnvironment(CloudBase): """Base class for cloud environment plugins. Updates integration test environment after delegation.""" diff --git a/test/runner/lib/cloud/aws.py b/test/runner/lib/cloud/aws.py index 0720fe2bc4..62b1f38cd7 100644 --- a/test/runner/lib/cloud/aws.py +++ b/test/runner/lib/cloud/aws.py @@ -6,7 +6,6 @@ import os from lib.util import ( ApplicationError, display, - is_shippable, ConfigParser, ) @@ -33,10 +32,7 @@ class AwsCloudProvider(CloudProvider): aci = self._create_ansible_core_ci() - if os.path.isfile(aci.ci_key): - return - - if is_shippable(): + if aci.available: return super(AwsCloudProvider, self).filter(targets, exclude) diff --git a/test/runner/lib/cloud/azure.py b/test/runner/lib/cloud/azure.py index e539d3abc1..0d48f1775d 100644 --- a/test/runner/lib/cloud/azure.py +++ b/test/runner/lib/cloud/azure.py @@ -6,7 +6,6 @@ import os from lib.util import ( ApplicationError, display, - is_shippable, ConfigParser, ) @@ -50,15 +49,12 @@ class AzureCloudProvider(CloudProvider): aci = self._create_ansible_core_ci() - if os.path.isfile(aci.ci_key): + if aci.available: return if os.path.isfile(self.SHERLOCK_CONFIG_PATH): return - if is_shippable(): - return - super(AzureCloudProvider, self).filter(targets, exclude) def setup(self): diff --git a/test/runner/lib/cloud/cs.py b/test/runner/lib/cloud/cs.py index 2ad54ed61b..c64f17aaba 100644 --- a/test/runner/lib/cloud/cs.py +++ b/test/runner/lib/cloud/cs.py @@ -17,7 +17,6 @@ from lib.util import ( ApplicationError, display, SubprocessError, - is_shippable, ConfigParser, ) @@ -106,7 +105,7 @@ class CsCloudProvider(CloudProvider): def cleanup(self): """Clean up the cloud resource and any temporary configuration files after tests complete.""" if self.container_name: - if is_shippable(): + if self.ci_provider.code: docker_rm(self.args, self.container_name) elif not self.args.explain: display.notice('Remember to run `docker rm -f %s` when finished testing.' % self.container_name) diff --git a/test/runner/lib/cloud/hcloud.py b/test/runner/lib/cloud/hcloud.py index 9d38aeb5bb..5cc24ea5a2 100644 --- a/test/runner/lib/cloud/hcloud.py +++ b/test/runner/lib/cloud/hcloud.py @@ -5,7 +5,6 @@ import os from lib.util import ( display, - is_shippable, ConfigParser, ) @@ -41,10 +40,7 @@ class HcloudCloudProvider(CloudProvider): aci = self._create_ansible_core_ci() - if os.path.isfile(aci.ci_key): - return - - if is_shippable(): + if aci.available: return super(HcloudCloudProvider, self).filter(targets, exclude) diff --git a/test/runner/lib/cloud/tower.py b/test/runner/lib/cloud/tower.py index fcc8bb5d9c..e55d55781c 100644 --- a/test/runner/lib/cloud/tower.py +++ b/test/runner/lib/cloud/tower.py @@ -7,7 +7,6 @@ import time from lib.util import ( display, ApplicationError, - is_shippable, run_command, SubprocessError, ConfigParser, @@ -45,10 +44,7 @@ class TowerCloudProvider(CloudProvider): aci = get_tower_aci(self.args) - if os.path.isfile(aci.ci_key): - return - - if is_shippable(): + if aci.available: return super(TowerCloudProvider, self).filter(targets, exclude) diff --git a/test/runner/lib/config.py b/test/runner/lib/config.py index 4c276333e7..95b059f982 100644 --- a/test/runner/lib/config.py +++ b/test/runner/lib/config.py @@ -7,13 +7,14 @@ import sys from lib.util import ( CommonConfig, - is_shippable, docker_qualify_image, find_python, generate_pip_command, get_docker_completion, ) +from lib import types as t + from lib.metadata import ( Metadata, ) @@ -56,6 +57,7 @@ class EnvironmentConfig(CommonConfig): self.remote_stage = args.remote_stage # type: str self.remote_provider = args.remote_provider # type: str + self.remote_endpoint = args.remote_endpoint # type: t.Optional[str] self.remote_aws_region = args.remote_aws_region # type: str self.remote_terminate = args.remote_terminate # type: str @@ -118,6 +120,7 @@ class TestConfig(EnvironmentConfig): self.unstaged = args.unstaged # type: bool self.changed_from = args.changed_from # type: str self.changed_path = args.changed_path # type: list [str] + self.base_branch = args.base_branch # type: str self.lint = args.lint if 'lint' in args else False # type: bool self.junit = args.junit if 'junit' in args else False # type: bool @@ -157,16 +160,6 @@ class SanityConfig(TestConfig): self.list_tests = args.list_tests # type: bool self.allow_disabled = args.allow_disabled # type: bool - if args.base_branch: - self.base_branch = args.base_branch # str - elif is_shippable(): - self.base_branch = os.environ.get('BASE_BRANCH', '') # str - - if self.base_branch: - self.base_branch = 'origin/%s' % self.base_branch - else: - self.base_branch = '' - class IntegrationConfig(TestConfig): """Configuration for the integration command.""" diff --git a/test/runner/lib/core_ci.py b/test/runner/lib/core_ci.py index e8425e090a..b10cbcaa4a 100644 --- a/test/runner/lib/core_ci.py +++ b/test/runner/lib/core_ci.py @@ -22,16 +22,19 @@ from lib.util import ( run_command, make_dirs, display, - is_shippable, ) from lib.config import ( EnvironmentConfig, ) +from lib.ci import ( + AuthContext, + get_ci_provider, +) + AWS_ENDPOINTS = { - 'us-east-1': 'https://14blg63h2i.execute-api.us-east-1.amazonaws.com', - 'us-east-2': 'https://g5xynwbk96.execute-api.us-east-2.amazonaws.com', + 'us-east-1': 'https://ansible-core-ci.testing.ansible.com', } @@ -56,9 +59,9 @@ class AnsibleCoreCI(object): self.instance_id = None self.endpoint = None self.max_threshold = 1 + self.ci_provider = get_ci_provider() + self.auth_context = AuthContext() self.name = name if name else '%s-%s' % (self.platform, self.version) - self.ci_key = os.path.expanduser('~/.ansible-core-ci.key') - self.resource = 'jobs' # Assign each supported platform to one provider. # This is used to determine the provider from the platform when no provider is specified. @@ -100,19 +103,22 @@ class AnsibleCoreCI(object): self.path = os.path.expanduser('~/.ansible/test/instances/%s-%s-%s' % (self.name, self.provider, self.stage)) if self.provider in ('aws', 'azure'): - if self.provider != 'aws': - self.resource = self.provider - if args.remote_aws_region: + display.warning('The --remote-aws-region option is obsolete and will be removed in a future version of ansible-test.') # permit command-line override of region selection region = args.remote_aws_region # use a dedicated CI key when overriding the region selection - self.ci_key += '.%s' % args.remote_aws_region + self.auth_context.region = args.remote_aws_region else: region = 'us-east-1' self.path = "%s-%s" % (self.path, region) - self.endpoints = (AWS_ENDPOINTS[region],) + + if self.args.remote_endpoint: + self.endpoints = (self.args.remote_endpoint,) + else: + self.endpoints = (AWS_ENDPOINTS[region],) + self.ssh_key = SshKey(args) if self.platform == 'windows': @@ -120,8 +126,10 @@ class AnsibleCoreCI(object): else: self.port = 22 elif self.provider == 'parallels': - self.endpoints = self._get_parallels_endpoints() - self.max_threshold = 6 + if self.args.remote_endpoint: + self.endpoints = (self.args.remote_endpoint,) + else: + self.endpoints = (AWS_ENDPOINTS['us-east-1'],) self.ssh_key = SshKey(args) self.port = None @@ -169,8 +177,8 @@ class AnsibleCoreCI(object): display.info('Getting available endpoints...', verbosity=1) sleep = 3 - for _ in range(1, 10): - response = client.get('https://s3.amazonaws.com/ansible-ci-files/ansible-test/parallels-endpoints.txt') + for _iteration in range(1, 10): + response = client.get('https://ansible-ci-files.s3.amazonaws.com/ansible-test/parallels-endpoints.txt') if response.status_code == 200: endpoints = tuple(response.response.splitlines()) @@ -182,6 +190,11 @@ class AnsibleCoreCI(object): raise ApplicationError('Unable to get available endpoints.') + @property + def available(self): + """Return True if Ansible Core CI is supported.""" + return self.ci_provider.supports_core_ci_auth(self.auth_context) + def start(self): """Start instance.""" if self.started: @@ -189,31 +202,7 @@ class AnsibleCoreCI(object): verbosity=1) return None - if is_shippable(): - return self.start_shippable() - - return self.start_remote() - - def start_remote(self): - """Start instance for remote development/testing.""" - with open(self.ci_key, 'r') as key_fd: - auth_key = key_fd.read().strip() - - return self._start(dict( - remote=dict( - key=auth_key, - nonce=None, - ), - )) - - def start_shippable(self): - """Start instance on Shippable.""" - return self._start(dict( - shippable=dict( - run_id=os.environ['SHIPPABLE_BUILD_ID'], - job_number=int(os.environ['SHIPPABLE_JOB_NUMBER']), - ), - )) + return self._start(self.ci_provider.prepare_core_ci_auth(self.auth_context)) def stop(self): """Stop instance.""" @@ -317,7 +306,7 @@ class AnsibleCoreCI(object): @property def _uri(self): - return '%s/%s/%s/%s' % (self.endpoint, self.stage, self.resource, self.instance_id) + return '%s/%s/%s/%s' % (self.endpoint, self.stage, self.provider, self.instance_id) def _start(self, auth): """Start instance.""" diff --git a/test/runner/lib/delegation.py b/test/runner/lib/delegation.py index 13dcb6d2f2..b3caf4c210 100644 --- a/test/runner/lib/delegation.py +++ b/test/runner/lib/delegation.py @@ -68,6 +68,10 @@ from lib.target import ( IntegrationTarget, ) +from lib.ci import ( + get_ci_provider, +) + def check_delegation_args(args): """ @@ -91,6 +95,8 @@ def delegate(args, exclude, require, integration_targets): :rtype: bool """ if isinstance(args, TestConfig): + args.metadata.ci_provider = get_ci_provider().code + with tempfile.NamedTemporaryFile(prefix='metadata-', suffix='.json', dir=os.getcwd()) as metadata_fd: args.metadata_path = os.path.basename(metadata_fd.name) args.metadata.to_file(args.metadata_path) @@ -472,8 +478,10 @@ def generate_command(args, python_interpreter, path, options, exclude, require): if isinstance(args, ShellConfig): cmd = create_shell_command(cmd) elif isinstance(args, SanityConfig): - if args.base_branch: - cmd += ['--base-branch', args.base_branch] + base_branch = args.base_branch or get_ci_provider().get_base_branch() + + if base_branch: + cmd += ['--base-branch', base_branch] return cmd diff --git a/test/runner/lib/encoding.py b/test/runner/lib/encoding.py new file mode 100644 index 0000000000..8e014794c7 --- /dev/null +++ b/test/runner/lib/encoding.py @@ -0,0 +1,41 @@ +"""Functions for encoding and decoding strings.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from . import types as t + +ENCODING = 'utf-8' + +Text = type(u'') + + +def to_optional_bytes(value, errors='strict'): # type: (t.Optional[t.AnyStr], str) -> t.Optional[bytes] + """Return the given value as bytes encoded using UTF-8 if not already bytes, or None if the value is None.""" + return None if value is None else to_bytes(value, errors) + + +def to_optional_text(value, errors='strict'): # type: (t.Optional[t.AnyStr], str) -> t.Optional[t.Text] + """Return the given value as text decoded using UTF-8 if not already text, or None if the value is None.""" + return None if value is None else to_text(value, errors) + + +def to_bytes(value, errors='strict'): # type: (t.AnyStr, str) -> bytes + """Return the given value as bytes encoded using UTF-8 if not already bytes.""" + if isinstance(value, bytes): + return value + + if isinstance(value, Text): + return value.encode(ENCODING, errors) + + raise Exception('value is not bytes or text: %s' % type(value)) + + +def to_text(value, errors='strict'): # type: (t.AnyStr, str) -> t.Text + """Return the given value as text decoded using UTF-8 if not already text.""" + if isinstance(value, bytes): + return value.decode(ENCODING, errors) + + if isinstance(value, Text): + return value + + raise Exception('value is not bytes or text: %s' % type(value)) diff --git a/test/runner/lib/env.py b/test/runner/lib/env.py index 03cfc0daa3..106cd84740 100644 --- a/test/runner/lib/env.py +++ b/test/runner/lib/env.py @@ -7,7 +7,6 @@ import json import functools import os import platform -import re import signal import sys import time @@ -29,10 +28,6 @@ from lib.ansible_util import ( ansible_environment, ) -from lib.git import ( - Git, -) - from lib.docker_util import ( docker_info, docker_version @@ -50,6 +45,10 @@ from lib.test import ( TestTimeout, ) +from lib.ci import ( + get_ci_provider, +) + class EnvConfig(CommonConfig): """Configuration for the tools command.""" @@ -85,7 +84,7 @@ def show_dump_env(args): ), docker=get_docker_details(args), environ=os.environ.copy(), - git=get_git_details(args), + git=get_ci_provider().get_git_details(args), platform=dict( datetime=datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%SZ'), platform=platform.platform(), @@ -284,75 +283,3 @@ def get_docker_details(args): ) return docker_details - - -def get_git_details(args): - """ - :type args: CommonConfig - :rtype: dict[str, any] - """ - commit = os.environ.get('COMMIT') - base_commit = os.environ.get('BASE_COMMIT') - - git_details = dict( - base_commit=base_commit, - commit=commit, - merged_commit=get_merged_commit(args, commit), - root=os.getcwd(), - ) - - return git_details - - -def get_merged_commit(args, commit): - """ - :type args: CommonConfig - :type commit: str - :rtype: str | None - """ - if not commit: - return None - - git = Git(args) - - try: - show_commit = git.run_git(['show', '--no-patch', '--no-abbrev', commit]) - except SubprocessError as ex: - # This should only fail for pull requests where the commit does not exist. - # Merge runs would fail much earlier when attempting to checkout the commit. - raise ApplicationError('Commit %s was not found:\n\n%s\n\n' - 'GitHub may not have fully replicated the commit across their infrastructure.\n' - 'It is also possible the commit was removed by a force push between job creation and execution.\n' - 'Find the latest run for the pull request and restart failed jobs as needed.' - % (commit, ex.stderr.strip())) - - head_commit = git.run_git(['show', '--no-patch', '--no-abbrev', 'HEAD']) - - if show_commit == head_commit: - # Commit is HEAD, so this is not a pull request or the base branch for the pull request is up-to-date. - return None - - match_merge = re.search(r'^Merge: (?P<parents>[0-9a-f]{40} [0-9a-f]{40})$', head_commit, flags=re.MULTILINE) - - if not match_merge: - # The most likely scenarios resulting in a failure here are: - # A new run should or does supersede this job, but it wasn't cancelled in time. - # A job was superseded and then later restarted. - raise ApplicationError('HEAD is not commit %s or a merge commit:\n\n%s\n\n' - 'This job has likely been superseded by another run due to additional commits being pushed.\n' - 'Find the latest run for the pull request and restart failed jobs as needed.' - % (commit, head_commit.strip())) - - parents = set(match_merge.group('parents').split(' ')) - - if len(parents) != 2: - raise ApplicationError('HEAD is a %d-way octopus merge.' % len(parents)) - - if commit not in parents: - raise ApplicationError('Commit %s is not a parent of HEAD.' % commit) - - parents.remove(commit) - - last_commit = parents.pop() - - return last_commit diff --git a/test/runner/lib/executor.py b/test/runner/lib/executor.py index 970ff52e81..305ce64083 100644 --- a/test/runner/lib/executor.py +++ b/test/runner/lib/executor.py @@ -48,8 +48,6 @@ from lib.util import ( intercept_command, remove_tree, make_dirs, - is_shippable, - is_binary_file, find_executable, raw_command, get_python_path, @@ -86,13 +84,8 @@ from lib.target import ( walk_units_targets, ) -from lib.changes import ( - ShippableChanges, - LocalChanges, -) - -from lib.git import ( - Git, +from lib.ci import ( + get_ci_provider, ) from lib.classification import ( @@ -1151,7 +1144,7 @@ def integration_environment(args, target, test_dir, inventory_path, ansible_conf integration = dict( JUNIT_OUTPUT_DIR=os.path.abspath('test/results/junit'), ANSIBLE_CALLBACK_WHITELIST=','.join(sorted(set(callback_plugins))), - ANSIBLE_TEST_CI=args.metadata.ci_provider, + ANSIBLE_TEST_CI=args.metadata.ci_provider or get_ci_provider().code, OUTPUT_DIR=test_dir, INVENTORY_PATH=os.path.abspath(inventory_path), ) @@ -1388,16 +1381,13 @@ def detect_changes(args): :type args: TestConfig :rtype: list[str] | None """ - if args.changed and is_shippable(): - display.info('Shippable detected, collecting parameters from environment.') - paths = detect_changes_shippable(args) + if args.changed: + paths = get_ci_provider().detect_changes(args) elif args.changed_from or args.changed_path: paths = args.changed_path or [] if args.changed_from: with open(args.changed_from, 'r') as changes_fd: paths += changes_fd.read().splitlines() - elif args.changed: - paths = detect_changes_local(args) else: return None # change detection not enabled @@ -1412,85 +1402,6 @@ def detect_changes(args): return paths -def detect_changes_shippable(args): - """Initialize change detection on Shippable. - :type args: TestConfig - :rtype: list[str] | None - """ - git = Git(args) - result = ShippableChanges(args, git) - - if result.is_pr: - job_type = 'pull request' - elif result.is_tag: - job_type = 'tag' - else: - job_type = 'merge commit' - - display.info('Processing %s for branch %s commit %s' % (job_type, result.branch, result.commit)) - - if not args.metadata.changes: - args.metadata.populate_changes(result.diff) - - return result.paths - - -def detect_changes_local(args): - """ - :type args: TestConfig - :rtype: list[str] - """ - git = Git(args) - result = LocalChanges(args, git) - - display.info('Detected branch %s forked from %s at commit %s' % ( - result.current_branch, result.fork_branch, result.fork_point)) - - if result.untracked and not args.untracked: - display.warning('Ignored %s untracked file(s). Use --untracked to include them.' % - len(result.untracked)) - - if result.committed and not args.committed: - display.warning('Ignored %s committed change(s). Omit --ignore-committed to include them.' % - len(result.committed)) - - if result.staged and not args.staged: - display.warning('Ignored %s staged change(s). Omit --ignore-staged to include them.' % - len(result.staged)) - - if result.unstaged and not args.unstaged: - display.warning('Ignored %s unstaged change(s). Omit --ignore-unstaged to include them.' % - len(result.unstaged)) - - names = set() - - if args.tracked: - names |= set(result.tracked) - if args.untracked: - names |= set(result.untracked) - if args.committed: - names |= set(result.committed) - if args.staged: - names |= set(result.staged) - if args.unstaged: - names |= set(result.unstaged) - - if not args.metadata.changes: - args.metadata.populate_changes(result.diff) - - for path in result.untracked: - if is_binary_file(path): - args.metadata.changes[path] = ((0, 0),) - continue - - with open(path, 'r') as source_fd: - line_count = len(source_fd.read().splitlines()) - - args.metadata.changes[path] = ((1, line_count),) - - return sorted(names) - - def get_integration_filter(args, targets): """ :type args: IntegrationConfig diff --git a/test/runner/lib/io.py b/test/runner/lib/io.py new file mode 100644 index 0000000000..0f61cd2df2 --- /dev/null +++ b/test/runner/lib/io.py @@ -0,0 +1,94 @@ +"""Functions for disk IO.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import errno +import io +import json +import os + +from . import types as t + +from .encoding import ( + ENCODING, + to_bytes, + to_text, +) + + +def read_json_file(path): # type: (t.AnyStr) -> t.Any + """Parse and return the json content from the specified path.""" + return json.loads(read_text_file(path)) + + +def read_text_file(path): # type: (t.AnyStr) -> t.Text + """Return the contents of the specified path as text.""" + return to_text(read_binary_file(path)) + + +def read_binary_file(path): # type: (t.AnyStr) -> bytes + """Return the contents of the specified path as bytes.""" + with open_binary_file(path) as file: + return file.read() + + +def make_dirs(path): # type: (str) -> None + """Create a directory at path, including any necessary parent directories.""" + try: + os.makedirs(to_bytes(path)) + except OSError as ex: + if ex.errno != errno.EEXIST: + raise + + +def write_json_file(path, # type: str + content, # type: t.Union[t.List[t.Any], t.Dict[str, t.Any]] + create_directories=False, # type: bool + formatted=True, # type: bool + encoder=None, # type: t.Optional[t.Callable[[t.Any], t.Any]] + ): # type: (...) -> None + """Write the given json content to the specified path, optionally creating missing directories.""" + text_content = json.dumps(content, + sort_keys=formatted, + indent=4 if formatted else None, + separators=(', ', ': ') if formatted else (',', ':'), + cls=encoder, + ) + '\n' + + write_text_file(path, text_content, create_directories=create_directories) + + +def write_text_file(path, content, create_directories=False): # type: (str, str, bool) -> None + """Write the given text content to the specified path, optionally creating missing directories.""" + if create_directories: + make_dirs(os.path.dirname(path)) + + with open_binary_file(path, 'wb') as file: + file.write(to_bytes(content)) + + +def open_text_file(path, mode='r'): # type: (str, str) -> t.TextIO + """Open the given path for text access.""" + if 'b' in mode: + raise Exception('mode cannot include "b" for text files: %s' % mode) + + # noinspection PyTypeChecker + return io.open(to_bytes(path), mode, encoding=ENCODING) + + +def open_binary_file(path, mode='rb'): # type: (str, str) -> t.BinaryIO + """Open the given path for binary access.""" + if 'b' not in mode: + raise Exception('mode must include "b" for binary files: %s' % mode) + + # noinspection PyTypeChecker + return io.open(to_bytes(path), mode) + + +class SortedSetEncoder(json.JSONEncoder): + """Encode sets as sorted lists.""" + def default(self, obj): # pylint: disable=method-hidden, arguments-differ + if isinstance(obj, set): + return sorted(obj) + + return super(SortedSetEncoder).default(self, obj) diff --git a/test/runner/lib/metadata.py b/test/runner/lib/metadata.py index 0bb13f826d..4a78aba42c 100644 --- a/test/runner/lib/metadata.py +++ b/test/runner/lib/metadata.py @@ -5,7 +5,6 @@ import json from lib.util import ( display, - is_shippable, ) from lib.diff import ( @@ -22,11 +21,7 @@ class Metadata(object): self.cloud_config = None # type: dict [str, str] self.instance_config = None # type: list[dict[str, str]] self.change_description = None # type: ChangeDescription - - if is_shippable(): - self.ci_provider = 'shippable' - else: - self.ci_provider = '' + self.ci_provider = None # type: t.Optional[str] def populate_changes(self, diff): """ diff --git a/test/runner/lib/sanity/validate_modules.py b/test/runner/lib/sanity/validate_modules.py index af51118245..e516b521e6 100644 --- a/test/runner/lib/sanity/validate_modules.py +++ b/test/runner/lib/sanity/validate_modules.py @@ -28,6 +28,10 @@ from lib.config import ( SanityConfig, ) +from lib.ci import ( + get_ci_provider, +) + from lib.test import ( calculate_confidence, calculate_best_confidence, @@ -91,12 +95,14 @@ class ValidateModulesTest(SanitySingleVersion): ignore[path][code] = line - if args.base_branch: + base_branch = args.base_branch or get_ci_provider().get_base_branch() + + if base_branch: cmd.extend([ - '--base-branch', args.base_branch, + '--base-branch', base_branch, ]) else: - display.warning('Cannot perform module comparison against the base branch. Base branch not detected when running locally.') + display.warning('Cannot perform module comparison against the base branch because the base branch was not detected.') try: stdout, stderr = run_command(args, cmd, env=env, capture=True) diff --git a/test/runner/lib/test.py b/test/runner/lib/test.py index 235de16aab..37028e4710 100644 --- a/test/runner/lib/test.py +++ b/test/runner/lib/test.py @@ -191,7 +191,7 @@ One or more of the following situations may be responsible: timestamp = datetime.datetime.utcnow().replace(microsecond=0).isoformat() - # hack to avoid requiring junit-xml, which isn't pre-installed on Shippable outside our test containers + # hack to avoid requiring junit-xml, which may not be pre-installed outside our test containers xml = ''' <?xml version="1.0" encoding="utf-8"?> <testsuites disabled="0" errors="1" failures="0" tests="1" time="0.0"> diff --git a/test/runner/lib/types.py b/test/runner/lib/types.py new file mode 100644 index 0000000000..46ef70668e --- /dev/null +++ b/test/runner/lib/types.py @@ -0,0 +1,32 @@ +"""Import wrapper for type hints when available.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +TYPE_CHECKING = False + +try: + from typing import ( + Any, + AnyStr, + BinaryIO, + Callable, + Dict, + FrozenSet, + Generator, + IO, + Iterable, + Iterator, + List, + Optional, + Pattern, + Set, + Text, + TextIO, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, + ) +except ImportError: + pass diff --git a/test/runner/lib/util.py b/test/runner/lib/util.py index b0f30b6873..6629366e4b 100644 --- a/test/runner/lib/util.py +++ b/test/runner/lib/util.py @@ -113,13 +113,6 @@ def parse_parameterized_completion(value): return name, data -def is_shippable(): - """ - :rtype: bool - """ - return os.environ.get('SHIPPABLE') == 'true' - - def remove_file(path): """ :type path: str diff --git a/test/sanity/pylint/config/ansible-test b/test/sanity/pylint/config/ansible-test index 14be00b2c7..887339ef5e 100644 --- a/test/sanity/pylint/config/ansible-test +++ b/test/sanity/pylint/config/ansible-test @@ -1,6 +1,8 @@ [MESSAGES CONTROL] disable= + invalid-name, + no-self-use, too-few-public-methods, too-many-arguments, too-many-branches, diff --git a/test/units/ansible_test/__init__.py b/test/units/ansible_test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/test/units/ansible_test/__init__.py diff --git a/test/units/ansible_test/ci/__init__.py b/test/units/ansible_test/ci/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/test/units/ansible_test/ci/__init__.py diff --git a/test/units/ansible_test/ci/test_azp.py b/test/units/ansible_test/ci/test_azp.py new file mode 100644 index 0000000000..2676a12e3c --- /dev/null +++ b/test/units/ansible_test/ci/test_azp.py @@ -0,0 +1,31 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from .util import common_auth_test + + +def test_auth(): + # noinspection PyProtectedMember + from lib.ci.azp import ( + AzurePipelinesAuthHelper, + ) + + class TestAzurePipelinesAuthHelper(AzurePipelinesAuthHelper): + def __init__(self): + self.public_key_pem = None + self.private_key_pem = None + + def publish_public_key(self, public_key_pem): + # avoid publishing key + self.public_key_pem = public_key_pem + + def initialize_private_key(self): + # cache in memory instead of on disk + if not self.private_key_pem: + self.private_key_pem = self.generate_private_key() + + return self.private_key_pem + + auth = TestAzurePipelinesAuthHelper() + + common_auth_test(auth) diff --git a/test/units/ansible_test/ci/test_shippable.py b/test/units/ansible_test/ci/test_shippable.py new file mode 100644 index 0000000000..52a332be46 --- /dev/null +++ b/test/units/ansible_test/ci/test_shippable.py @@ -0,0 +1,31 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from .util import common_auth_test + + +def test_auth(): + # noinspection PyProtectedMember + from lib.ci.shippable import ( + ShippableAuthHelper, + ) + + class TestShippableAuthHelper(ShippableAuthHelper): + def __init__(self): + self.public_key_pem = None + self.private_key_pem = None + + def publish_public_key(self, public_key_pem): + # avoid publishing key + self.public_key_pem = public_key_pem + + def initialize_private_key(self): + # cache in memory instead of on disk + if not self.private_key_pem: + self.private_key_pem = self.generate_private_key() + + return self.private_key_pem + + auth = TestShippableAuthHelper() + + common_auth_test(auth) diff --git a/test/units/ansible_test/ci/util.py b/test/units/ansible_test/ci/util.py new file mode 100644 index 0000000000..ba8e358bc8 --- /dev/null +++ b/test/units/ansible_test/ci/util.py @@ -0,0 +1,53 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import base64 +import json +import re + + +def common_auth_test(auth): + private_key_pem = auth.initialize_private_key() + public_key_pem = auth.public_key_pem + + extract_pem_key(private_key_pem, private=True) + extract_pem_key(public_key_pem, private=False) + + request = dict(hello='World') + auth.sign_request(request) + + verify_signature(request, public_key_pem) + + +def extract_pem_key(value, private): + assert isinstance(value, type(u'')) + + key_type = '(EC )?PRIVATE' if private else 'PUBLIC' + pattern = r'^-----BEGIN ' + key_type + r' KEY-----\n(?P<key>.*?)\n-----END ' + key_type + r' KEY-----\n$' + match = re.search(pattern, value, flags=re.DOTALL) + + assert match, 'key "%s" does not match pattern "%s"' % (value, pattern) + + base64.b64decode(match.group('key')) # make sure the key can be decoded + + +def verify_signature(request, public_key_pem): + signature = request.pop('signature') + payload_bytes = json.dumps(request, sort_keys=True).encode() + + assert isinstance(signature, type(u'')) + + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + from cryptography.hazmat.primitives.serialization import load_pem_public_key + + public_key = load_pem_public_key(public_key_pem.encode(), default_backend()) + + verifier = public_key.verifier( + base64.b64decode(signature.encode()), + ec.ECDSA(hashes.SHA256()), + ) + + verifier.update(payload_bytes) + verifier.verify() diff --git a/test/units/ansible_test/conftest.py b/test/units/ansible_test/conftest.py new file mode 100644 index 0000000000..e045026cca --- /dev/null +++ b/test/units/ansible_test/conftest.py @@ -0,0 +1,14 @@ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +import os +import pytest +import sys + + +@pytest.fixture(autouse=True, scope='session') +def ansible_test(): + """Make ansible_test available on sys.path for unit testing ansible-test.""" + test_lib = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'runner') + sys.path.insert(0, test_lib) |