summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatt Clay <matt@mystile.com>2017-02-03 17:19:59 -0800
committerMatt Clay <matt@mystile.com>2017-02-15 11:57:16 -0800
commitcb93ecaef9c0a9b3d3328b5158e1d54902b8d714 (patch)
tree3c714669ad8616e3ea7ffdf0da21f6caf9125dde
parent6176c95838271d56b3eb4a2024ce53471ded20f2 (diff)
downloadansible-cb93ecaef9c0a9b3d3328b5158e1d54902b8d714.tar.gz
Fix @contextmanager leak on exception. (#21031)
* Fix @contextmanager leak on exception. * Fix test leaks of global module args cache. (cherry picked from commit 272ff10fa13e949eb637506969c40da453aae821)
-rw-r--r--test/units/mock/procenv.py31
-rw-r--r--test/units/module_utils/basic/test__log_invocation.py1
-rw-r--r--test/units/module_utils/test_basic.py5
-rw-r--r--test/units/module_utils/test_distribution_version.py1
4 files changed, 25 insertions, 13 deletions
diff --git a/test/units/mock/procenv.py b/test/units/mock/procenv.py
index e9d470c079..6cf69a7acc 100644
--- a/test/units/mock/procenv.py
+++ b/test/units/mock/procenv.py
@@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()):
context manager that temporarily masks the test runner's values for stdin and argv
"""
real_stdin = sys.stdin
+ real_argv = sys.argv
if PY3:
- sys.stdin = StringIO(stdin_data)
- sys.stdin.buffer = BytesIO(to_bytes(stdin_data))
+ fake_stream = StringIO(stdin_data)
+ fake_stream.buffer = BytesIO(to_bytes(stdin_data))
else:
- sys.stdin = BytesIO(to_bytes(stdin_data))
+ fake_stream = BytesIO(to_bytes(stdin_data))
- real_argv = sys.argv
- sys.argv = argv_data
- yield
- sys.stdin = real_stdin
- sys.argv = real_argv
+ try:
+ sys.stdin = fake_stream
+ sys.argv = argv_data
+
+ yield
+ finally:
+ sys.stdin = real_stdin
+ sys.argv = real_argv
@contextmanager
@@ -56,13 +60,18 @@ def swap_stdout():
context manager that temporarily replaces stdout for tests that need to verify output
"""
old_stdout = sys.stdout
+
if PY3:
fake_stream = StringIO()
else:
fake_stream = BytesIO()
- sys.stdout = fake_stream
- yield fake_stream
- sys.stdout = old_stdout
+
+ try:
+ sys.stdout = fake_stream
+
+ yield fake_stream
+ finally:
+ sys.stdout = old_stdout
class ModuleTestCase(unittest.TestCase):
diff --git a/test/units/module_utils/basic/test__log_invocation.py b/test/units/module_utils/basic/test__log_invocation.py
index d4510c5efc..3723697bed 100644
--- a/test/units/module_utils/basic/test__log_invocation.py
+++ b/test/units/module_utils/basic/test__log_invocation.py
@@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase):
from ansible.module_utils import basic
# test basic log invocation
+ basic._ANSIBLE_ARGS = None
am = basic.AnsibleModule(
argument_spec=dict(
foo = dict(default=True, type='bool'),
diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py
index 24b3214f57..8ec28c9e62 100644
--- a/test/units/module_utils/test_basic.py
+++ b/test/units/module_utils/test_basic.py
@@ -315,6 +315,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"}))
with swap_stdin_and_argv(stdin_data=args):
+ basic._ANSIBLE_ARGS = None
self.assertRaises(
SystemExit,
basic.AnsibleModule,
@@ -331,6 +332,7 @@ class TestModuleUtilsBasic(ModuleTestCase):
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"}))
with swap_stdin_and_argv(stdin_data=args):
+ basic._ANSIBLE_ARGS = None
self.assertRaises(
SystemExit,
basic.AnsibleModule,
@@ -583,12 +585,11 @@ class TestModuleUtilsBasic(ModuleTestCase):
def test_module_utils_basic_ansible_module_is_special_selinux_path(self):
from ansible.module_utils import basic
- basic._ANSIBLE_ARGS = None
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"}))
with swap_stdin_and_argv(stdin_data=args):
-
+ basic._ANSIBLE_ARGS = None
am = basic.AnsibleModule(
argument_spec = dict(),
)
diff --git a/test/units/module_utils/test_distribution_version.py b/test/units/module_utils/test_distribution_version.py
index 584ac710c0..d0b7e27415 100644
--- a/test/units/module_utils/test_distribution_version.py
+++ b/test/units/module_utils/test_distribution_version.py
@@ -703,6 +703,7 @@ def test_distribution_version():
args = json.dumps(dict(ANSIBLE_MODULE_ARGS={}))
with swap_stdin_and_argv(stdin_data=args):
+ basic._ANSIBLE_ARGS = None
module = basic.AnsibleModule(argument_spec=dict())
for t in TESTSETS: