From 155374d95d8ecd235d3a3edd92dd6f6a23d59f11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Walter=20D=C3=B6rwald?= Date: Fri, 1 May 2009 19:58:58 +0000 Subject: Merged revisions 72167 via svnmerge from svn+ssh://pythondev@svn.python.org/python/trunk ........ r72167 | walter.doerwald | 2009-05-01 19:35:37 +0200 (Fr, 01 Mai 2009) | 5 lines Make test.test_support.EnvironmentVarGuard behave like a dictionary. All changes are mirrored to the underlying os.environ dict, but rolled back on exit from the with block. ........ --- Lib/test/support.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) (limited to 'Lib/test/support.py') diff --git a/Lib/test/support.py b/Lib/test/support.py index bdc61645ad..df114a2188 100644 --- a/Lib/test/support.py +++ b/Lib/test/support.py @@ -13,6 +13,7 @@ import shutil import warnings import unittest import importlib +import collections __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module", "verbose", "use_resources", "max_memuse", "record_original_stdout", @@ -510,26 +511,45 @@ class CleanImport(object): sys.modules.update(self.original_modules) -class EnvironmentVarGuard(object): +class EnvironmentVarGuard(collections.MutableMapping): """Class to help protect the environment variable properly. Can be used as a context manager.""" def __init__(self): + self._environ = os.environ self._changed = {} - def set(self, envvar, value): + def __getitem__(self, envvar): + return self._environ[envvar] + + def __setitem__(self, envvar, value): # Remember the initial value on the first access if envvar not in self._changed: - self._changed[envvar] = os.environ.get(envvar) - os.environ[envvar] = value + self._changed[envvar] = self._environ.get(envvar) + self._environ[envvar] = value - def unset(self, envvar): + def __delitem__(self, envvar): # Remember the initial value on the first access if envvar not in self._changed: - self._changed[envvar] = os.environ.get(envvar) - if envvar in os.environ: - del os.environ[envvar] + self._changed[envvar] = self._environ.get(envvar) + if envvar in self._environ: + del self._environ[envvar] + + def keys(self): + return self._environ.keys() + + def __iter__(self): + return iter(self._environ) + + def __len__(self): + return len(self._environ) + + def set(self, envvar, value): + self[envvar] = value + + def unset(self, envvar): + del self[envvar] def __enter__(self): return self @@ -537,10 +557,11 @@ class EnvironmentVarGuard(object): def __exit__(self, *ignore_exc): for (k, v) in self._changed.items(): if v is None: - if k in os.environ: - del os.environ[k] + if k in self._environ: + del self._environ[k] else: - os.environ[k] = v + self._environ[k] = v + class TransientResource(object): -- cgit v1.2.1