diff options
author | Matt Clay <matt@mystile.com> | 2017-02-03 17:19:59 -0800 |
---|---|---|
committer | Matt Clay <matt@mystile.com> | 2017-02-15 11:57:16 -0800 |
commit | cb93ecaef9c0a9b3d3328b5158e1d54902b8d714 (patch) | |
tree | 3c714669ad8616e3ea7ffdf0da21f6caf9125dde | |
parent | 6176c95838271d56b3eb4a2024ce53471ded20f2 (diff) | |
download | ansible-cb93ecaef9c0a9b3d3328b5158e1d54902b8d714.tar.gz |
Fix @contextmanager leak on exception. (#21031)
* Fix @contextmanager leak on exception.
* Fix test leaks of global module args cache.
(cherry picked from commit 272ff10fa13e949eb637506969c40da453aae821)
-rw-r--r-- | test/units/mock/procenv.py | 31 | ||||
-rw-r--r-- | test/units/module_utils/basic/test__log_invocation.py | 1 | ||||
-rw-r--r-- | test/units/module_utils/test_basic.py | 5 | ||||
-rw-r--r-- | test/units/module_utils/test_distribution_version.py | 1 |
4 files changed, 25 insertions, 13 deletions
diff --git a/test/units/mock/procenv.py b/test/units/mock/procenv.py index e9d470c079..6cf69a7acc 100644 --- a/test/units/mock/procenv.py +++ b/test/units/mock/procenv.py @@ -36,18 +36,22 @@ def swap_stdin_and_argv(stdin_data='', argv_data=tuple()): context manager that temporarily masks the test runner's values for stdin and argv """ real_stdin = sys.stdin + real_argv = sys.argv if PY3: - sys.stdin = StringIO(stdin_data) - sys.stdin.buffer = BytesIO(to_bytes(stdin_data)) + fake_stream = StringIO(stdin_data) + fake_stream.buffer = BytesIO(to_bytes(stdin_data)) else: - sys.stdin = BytesIO(to_bytes(stdin_data)) + fake_stream = BytesIO(to_bytes(stdin_data)) - real_argv = sys.argv - sys.argv = argv_data - yield - sys.stdin = real_stdin - sys.argv = real_argv + try: + sys.stdin = fake_stream + sys.argv = argv_data + + yield + finally: + sys.stdin = real_stdin + sys.argv = real_argv @contextmanager @@ -56,13 +60,18 @@ def swap_stdout(): context manager that temporarily replaces stdout for tests that need to verify output """ old_stdout = sys.stdout + if PY3: fake_stream = StringIO() else: fake_stream = BytesIO() - sys.stdout = fake_stream - yield fake_stream - sys.stdout = old_stdout + + try: + sys.stdout = fake_stream + + yield fake_stream + finally: + sys.stdout = old_stdout class ModuleTestCase(unittest.TestCase): diff --git a/test/units/module_utils/basic/test__log_invocation.py b/test/units/module_utils/basic/test__log_invocation.py index d4510c5efc..3723697bed 100644 --- a/test/units/module_utils/basic/test__log_invocation.py +++ b/test/units/module_utils/basic/test__log_invocation.py @@ -40,6 +40,7 @@ class TestModuleUtilsBasic(unittest.TestCase): from ansible.module_utils import basic # test basic log invocation + basic._ANSIBLE_ARGS = None am = basic.AnsibleModule( argument_spec=dict( foo = dict(default=True, type='bool'), diff --git a/test/units/module_utils/test_basic.py b/test/units/module_utils/test_basic.py index 24b3214f57..8ec28c9e62 100644 --- a/test/units/module_utils/test_basic.py +++ b/test/units/module_utils/test_basic.py @@ -315,6 +315,7 @@ class TestModuleUtilsBasic(ModuleTestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"foo":"hello", "bar": "bad", "bam": "bad"})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None self.assertRaises( SystemExit, basic.AnsibleModule, @@ -331,6 +332,7 @@ class TestModuleUtilsBasic(ModuleTestCase): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={"bam": "bad"})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None self.assertRaises( SystemExit, basic.AnsibleModule, @@ -583,12 +585,11 @@ class TestModuleUtilsBasic(ModuleTestCase): def test_module_utils_basic_ansible_module_is_special_selinux_path(self): from ansible.module_utils import basic - basic._ANSIBLE_ARGS = None args = json.dumps(dict(ANSIBLE_MODULE_ARGS={'_ansible_selinux_special_fs': "nfs,nfsd,foos"})) with swap_stdin_and_argv(stdin_data=args): - + basic._ANSIBLE_ARGS = None am = basic.AnsibleModule( argument_spec = dict(), ) diff --git a/test/units/module_utils/test_distribution_version.py b/test/units/module_utils/test_distribution_version.py index 584ac710c0..d0b7e27415 100644 --- a/test/units/module_utils/test_distribution_version.py +++ b/test/units/module_utils/test_distribution_version.py @@ -703,6 +703,7 @@ def test_distribution_version(): args = json.dumps(dict(ANSIBLE_MODULE_ARGS={})) with swap_stdin_and_argv(stdin_data=args): + basic._ANSIBLE_ARGS = None module = basic.AnsibleModule(argument_spec=dict()) for t in TESTSETS: |