summaryrefslogtreecommitdiff
path: root/nova
diff options
context:
space:
mode:
Diffstat (limited to 'nova')
-rw-r--r--nova/auth/manager.py77
-rw-r--r--nova/cloudpipe/pipelib.py63
-rw-r--r--nova/compat/flagfile.py2
-rw-r--r--nova/crypto.py104
-rw-r--r--nova/tests/test_crypto.py19
-rw-r--r--nova/tests/test_imagecache.py232
-rw-r--r--nova/tests/test_libvirt.py36
-rw-r--r--nova/utils.py14
-rw-r--r--nova/virt/libvirt/connection.py33
-rw-r--r--nova/virt/xenapi/vm_utils.py7
10 files changed, 261 insertions, 326 deletions
diff --git a/nova/auth/manager.py b/nova/auth/manager.py
index 2b67907bcb..e2516bcc16 100644
--- a/nova/auth/manager.py
+++ b/nova/auth/manager.py
@@ -24,9 +24,7 @@ Nova authentication management
"""
import os
-import shutil
import string # pylint: disable=W0402
-import tempfile
import uuid
import zipfile
@@ -767,45 +765,44 @@ class AuthManager(object):
pid = Project.safe_id(project)
private_key, signed_cert = crypto.generate_x509_cert(user.id, pid)
- tmpdir = tempfile.mkdtemp()
- zf = os.path.join(tmpdir, "temp.zip")
- zippy = zipfile.ZipFile(zf, 'w')
- if use_dmz and FLAGS.region_list:
- regions = {}
- for item in FLAGS.region_list:
- region, _sep, region_host = item.partition("=")
- regions[region] = region_host
- else:
- regions = {'nova': FLAGS.ec2_host}
- for region, host in regions.iteritems():
- rc = self.__generate_rc(user,
- pid,
- use_dmz,
- host)
- zippy.writestr(FLAGS.credential_rc_file % region, rc)
-
- zippy.writestr(FLAGS.credential_key_file, private_key)
- zippy.writestr(FLAGS.credential_cert_file, signed_cert)
-
- (vpn_ip, vpn_port) = self.get_project_vpn_data(project)
- if vpn_ip:
- configfile = open(FLAGS.vpn_client_template, "r")
- s = string.Template(configfile.read())
- configfile.close()
- config = s.substitute(keyfile=FLAGS.credential_key_file,
- certfile=FLAGS.credential_cert_file,
- ip=vpn_ip,
- port=vpn_port)
- zippy.writestr(FLAGS.credential_vpn_file, config)
- else:
- LOG.warn(_("No vpn data for project %s"), pid)
-
- zippy.writestr(FLAGS.ca_file, crypto.fetch_ca(pid))
- zippy.close()
- with open(zf, 'rb') as f:
- read_buffer = f.read()
+ with utils.tempdir() as tmpdir:
+ zf = os.path.join(tmpdir, "temp.zip")
+ zippy = zipfile.ZipFile(zf, 'w')
+ if use_dmz and FLAGS.region_list:
+ regions = {}
+ for item in FLAGS.region_list:
+ region, _sep, region_host = item.partition("=")
+ regions[region] = region_host
+ else:
+ regions = {'nova': FLAGS.ec2_host}
+ for region, host in regions.iteritems():
+ rc = self.__generate_rc(user,
+ pid,
+ use_dmz,
+ host)
+ zippy.writestr(FLAGS.credential_rc_file % region, rc)
+
+ zippy.writestr(FLAGS.credential_key_file, private_key)
+ zippy.writestr(FLAGS.credential_cert_file, signed_cert)
+
+ (vpn_ip, vpn_port) = self.get_project_vpn_data(project)
+ if vpn_ip:
+ configfile = open(FLAGS.vpn_client_template, "r")
+ s = string.Template(configfile.read())
+ configfile.close()
+ config = s.substitute(keyfile=FLAGS.credential_key_file,
+ certfile=FLAGS.credential_cert_file,
+ ip=vpn_ip,
+ port=vpn_port)
+ zippy.writestr(FLAGS.credential_vpn_file, config)
+ else:
+ LOG.warn(_("No vpn data for project %s"), pid)
+
+ zippy.writestr(FLAGS.ca_file, crypto.fetch_ca(pid))
+ zippy.close()
+ with open(zf, 'rb') as f:
+ read_buffer = f.read()
- shutil.rmtree(tmpdir)
return read_buffer
def get_environment_rc(self, user, project=None, use_dmz=True):
diff --git a/nova/cloudpipe/pipelib.py b/nova/cloudpipe/pipelib.py
index 4e5f7d4ba1..70c28d463e 100644
--- a/nova/cloudpipe/pipelib.py
+++ b/nova/cloudpipe/pipelib.py
@@ -65,36 +65,39 @@ class CloudPipe(object):
def get_encoded_zip(self, project_id):
# Make a payload.zip
- tmpfolder = tempfile.mkdtemp()
- filename = "payload.zip"
- zippath = os.path.join(tmpfolder, filename)
- z = zipfile.ZipFile(zippath, "w", zipfile.ZIP_DEFLATED)
- shellfile = open(FLAGS.boot_script_template, "r")
- s = string.Template(shellfile.read())
- shellfile.close()
- boot_script = s.substitute(cc_dmz=FLAGS.ec2_dmz_host,
- cc_port=FLAGS.ec2_port,
- dmz_net=FLAGS.dmz_net,
- dmz_mask=FLAGS.dmz_mask,
- num_vpn=FLAGS.cnt_vpn_clients)
- # genvpn, sign csr
- crypto.generate_vpn_files(project_id)
- z.writestr('autorun.sh', boot_script)
- crl = os.path.join(crypto.ca_folder(project_id), 'crl.pem')
- z.write(crl, 'crl.pem')
- server_key = os.path.join(crypto.ca_folder(project_id), 'server.key')
- z.write(server_key, 'server.key')
- ca_crt = os.path.join(crypto.ca_path(project_id))
- z.write(ca_crt, 'ca.crt')
- server_crt = os.path.join(crypto.ca_folder(project_id), 'server.crt')
- z.write(server_crt, 'server.crt')
- z.close()
- zippy = open(zippath, "r")
- # NOTE(vish): run instances expects encoded userdata, it is decoded
- # in the get_metadata_call. autorun.sh also decodes the zip file,
- # hence the double encoding.
- encoded = zippy.read().encode("base64").encode("base64")
- zippy.close()
+ with utils.tempdir() as tmpdir:
+ filename = "payload.zip"
+ zippath = os.path.join(tmpdir, filename)
+ z = zipfile.ZipFile(zippath, "w", zipfile.ZIP_DEFLATED)
+ shellfile = open(FLAGS.boot_script_template, "r")
+ s = string.Template(shellfile.read())
+ shellfile.close()
+ boot_script = s.substitute(cc_dmz=FLAGS.ec2_dmz_host,
+ cc_port=FLAGS.ec2_port,
+ dmz_net=FLAGS.dmz_net,
+ dmz_mask=FLAGS.dmz_mask,
+ num_vpn=FLAGS.cnt_vpn_clients)
+ # genvpn, sign csr
+ crypto.generate_vpn_files(project_id)
+ z.writestr('autorun.sh', boot_script)
+ crl = os.path.join(crypto.ca_folder(project_id), 'crl.pem')
+ z.write(crl, 'crl.pem')
+ server_key = os.path.join(crypto.ca_folder(project_id),
+ 'server.key')
+ z.write(server_key, 'server.key')
+ ca_crt = os.path.join(crypto.ca_path(project_id))
+ z.write(ca_crt, 'ca.crt')
+ server_crt = os.path.join(crypto.ca_folder(project_id),
+ 'server.crt')
+ z.write(server_crt, 'server.crt')
+ z.close()
+ zippy = open(zippath, "r")
+ # NOTE(vish): run instances expects encoded userdata, it is decoded
+ # in the get_metadata_call. autorun.sh also decodes the zip file,
+ # hence the double encoding.
+ encoded = zippy.read().encode("base64").encode("base64")
+ zippy.close()
+
return encoded
def launch_vpn_instance(self, project_id, user_id):
diff --git a/nova/compat/flagfile.py b/nova/compat/flagfile.py
index 8721d3485d..02d571cbf3 100644
--- a/nova/compat/flagfile.py
+++ b/nova/compat/flagfile.py
@@ -175,6 +175,8 @@ def handle_flagfiles_managed(args):
# Do stuff
# Any temporary fils have been removed
'''
+ # NOTE(johannes): Would be nice to use utils.tempdir(), but it
+ # causes an import loop
tempdir = tempfile.mkdtemp(prefix='nova-conf-')
try:
yield handle_flagfiles(args, tempdir=tempdir)
diff --git a/nova/crypto.py b/nova/crypto.py
index 64a4af6837..a5c7711855 100644
--- a/nova/crypto.py
+++ b/nova/crypto.py
@@ -27,9 +27,7 @@ from __future__ import absolute_import
import base64
import hashlib
import os
-import shutil
import string
-import tempfile
import Crypto.Cipher.AES
@@ -127,36 +125,26 @@ def _generate_fingerprint(public_key_file):
def generate_fingerprint(public_key):
- tmpdir = tempfile.mkdtemp()
- try:
- pubfile = os.path.join(tmpdir, 'temp.pub')
- with open(pubfile, 'w') as f:
- f.write(public_key)
- return _generate_fingerprint(pubfile)
- except exception.ProcessExecutionError:
- raise exception.InvalidKeypair()
- finally:
+ with utils.tempdir() as tmpdir:
try:
- shutil.rmtree(tmpdir)
- except IOError, e:
- LOG.debug(_('Could not remove tmpdir: %s'), str(e))
+ pubfile = os.path.join(tmpdir, 'temp.pub')
+ with open(pubfile, 'w') as f:
+ f.write(public_key)
+ return _generate_fingerprint(pubfile)
+ except exception.ProcessExecutionError:
+ raise exception.InvalidKeypair()
def generate_key_pair(bits=1024):
# what is the magic 65537?
- tmpdir = tempfile.mkdtemp()
- keyfile = os.path.join(tmpdir, 'temp')
- utils.execute('ssh-keygen', '-q', '-b', bits, '-N', '',
- '-t', 'rsa', '-f', keyfile)
- fingerprint = _generate_fingerprint('%s.pub' % (keyfile))
- private_key = open(keyfile).read()
- public_key = open(keyfile + '.pub').read()
-
- try:
- shutil.rmtree(tmpdir)
- except OSError, e:
- LOG.debug(_('Could not remove tmpdir: %s'), str(e))
+ with utils.tempdir() as tmpdir:
+ keyfile = os.path.join(tmpdir, 'temp')
+ utils.execute('ssh-keygen', '-q', '-b', bits, '-N', '',
+ '-t', 'rsa', '-f', keyfile)
+ fingerprint = _generate_fingerprint('%s.pub' % (keyfile))
+ private_key = open(keyfile).read()
+ public_key = open(keyfile + '.pub').read()
return (private_key, public_key, fingerprint)
@@ -233,19 +221,15 @@ def _user_cert_subject(user_id, project_id):
def generate_x509_cert(user_id, project_id, bits=1024):
"""Generate and sign a cert for user in project."""
subject = _user_cert_subject(user_id, project_id)
- tmpdir = tempfile.mkdtemp()
- keyfile = os.path.abspath(os.path.join(tmpdir, 'temp.key'))
- csrfile = os.path.join(tmpdir, 'temp.csr')
- utils.execute('openssl', 'genrsa', '-out', keyfile, str(bits))
- utils.execute('openssl', 'req', '-new', '-key', keyfile, '-out', csrfile,
- '-batch', '-subj', subject)
- private_key = open(keyfile).read()
- csr = open(csrfile).read()
- try:
- shutil.rmtree(tmpdir)
- except OSError, e:
- LOG.debug(_('Could not remove tmpdir: %s'), str(e))
+ with utils.tempdir() as tmpdir:
+ keyfile = os.path.abspath(os.path.join(tmpdir, 'temp.key'))
+ csrfile = os.path.join(tmpdir, 'temp.csr')
+ utils.execute('openssl', 'genrsa', '-out', keyfile, str(bits))
+ utils.execute('openssl', 'req', '-new', '-key', keyfile, '-out',
+ csrfile, '-batch', '-subj', subject)
+ private_key = open(keyfile).read()
+ csr = open(csrfile).read()
(serial, signed_csr) = sign_csr(csr, project_id)
fname = os.path.join(ca_folder(project_id), 'newcerts/%s.pem' % serial)
@@ -298,26 +282,30 @@ def sign_csr(csr_text, project_id=None):
def _sign_csr(csr_text, ca_folder):
- tmpfolder = tempfile.mkdtemp()
- inbound = os.path.join(tmpfolder, 'inbound.csr')
- outbound = os.path.join(tmpfolder, 'outbound.csr')
- csrfile = open(inbound, 'w')
- csrfile.write(csr_text)
- csrfile.close()
- LOG.debug(_('Flags path: %s'), ca_folder)
- start = os.getcwd()
- # Change working dir to CA
- if not os.path.exists(ca_folder):
- os.makedirs(ca_folder)
- os.chdir(ca_folder)
- utils.execute('openssl', 'ca', '-batch', '-out', outbound, '-config',
- './openssl.cnf', '-infiles', inbound)
- out, _err = utils.execute('openssl', 'x509', '-in', outbound,
- '-serial', '-noout')
- serial = string.strip(out.rpartition('=')[2])
- os.chdir(start)
- with open(outbound, 'r') as crtfile:
- return (serial, crtfile.read())
+ with utils.tempdir() as tmpdir:
+ inbound = os.path.join(tmpdir, 'inbound.csr')
+ outbound = os.path.join(tmpdir, 'outbound.csr')
+
+ with open(inbound, 'w') as csrfile:
+ csrfile.write(csr_text)
+
+ LOG.debug(_('Flags path: %s'), ca_folder)
+ start = os.getcwd()
+
+ # Change working dir to CA
+ if not os.path.exists(ca_folder):
+ os.makedirs(ca_folder)
+
+ os.chdir(ca_folder)
+ utils.execute('openssl', 'ca', '-batch', '-out', outbound, '-config',
+ './openssl.cnf', '-infiles', inbound)
+ out, _err = utils.execute('openssl', 'x509', '-in', outbound,
+ '-serial', '-noout')
+ serial = string.strip(out.rpartition('=')[2])
+ os.chdir(start)
+
+ with open(outbound, 'r') as crtfile:
+ return (serial, crtfile.read())
def _build_cipher(key, iv):
diff --git a/nova/tests/test_crypto.py b/nova/tests/test_crypto.py
index 2bfb345b95..ee48375825 100644
--- a/nova/tests/test_crypto.py
+++ b/nova/tests/test_crypto.py
@@ -17,8 +17,6 @@ Tests for Crypto module.
"""
import os
-import shutil
-import tempfile
import mox
@@ -50,9 +48,8 @@ class SymmetricKeyTestCase(test.TestCase):
class X509Test(test.TestCase):
def test_can_generate_x509(self):
- tmpdir = tempfile.mkdtemp()
- self.flags(ca_path=tmpdir)
- try:
+ with utils.tempdir() as tmpdir:
+ self.flags(ca_path=tmpdir)
crypto.ensure_ca_filesystem()
_key, cert_str = crypto.generate_x509_cert('fake', 'fake')
@@ -70,14 +67,10 @@ class X509Test(test.TestCase):
project_cert_file, '-verbose', signed_cert_file)
self.assertFalse(err)
- finally:
- shutil.rmtree(tmpdir)
-
def test_encrypt_decrypt_x509(self):
- tmpdir = tempfile.mkdtemp()
- self.flags(ca_path=tmpdir)
- project_id = "fake"
- try:
+ with utils.tempdir() as tmpdir:
+ self.flags(ca_path=tmpdir)
+ project_id = "fake"
crypto.ensure_ca_filesystem()
cert = crypto.fetch_ca(project_id)
public_key = os.path.join(tmpdir, "public.pem")
@@ -92,8 +85,6 @@ class X509Test(test.TestCase):
process_input=text)
dec = crypto.decrypt_text(project_id, enc)
self.assertEqual(text, dec)
- finally:
- shutil.rmtree(tmpdir)
class RevokeCertsTest(test.TestCase):
diff --git a/nova/tests/test_imagecache.py b/nova/tests/test_imagecache.py
index c9f300de17..242f9c0106 100644
--- a/nova/tests/test_imagecache.py
+++ b/nova/tests/test_imagecache.py
@@ -17,12 +17,11 @@
# under the License.
+import contextlib
import cStringIO
import hashlib
import logging
import os
-import shutil
-import tempfile
import time
from nova import test
@@ -58,9 +57,8 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(csum, None)
def test_read_stored_checksum(self):
- try:
- dirname = tempfile.mkdtemp()
- fname = os.path.join(dirname, 'aaa')
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
csum_input = 'fdghkfhkgjjksfdgjksjkghsdf'
f = open('%s.sha1' % fname, 'w')
@@ -71,9 +69,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(csum_input, csum_output)
- finally:
- shutil.rmtree(dirname)
-
def test_list_base_images(self):
listing = ['00000001',
'ephemeral_0_20_None',
@@ -281,13 +276,17 @@ class ImageCacheManagerTestCase(test.TestCase):
(base_file2, True, False),
(base_file3, False, True)])
+ @contextlib.contextmanager
def _intercept_log_messages(self):
- mylog = log.getLogger()
- stream = cStringIO.StringIO()
- handler = logging.StreamHandler(stream)
- handler.setFormatter(log.LegacyNovaFormatter())
- mylog.logger.addHandler(handler)
- return mylog, handler, stream
+ try:
+ mylog = log.getLogger()
+ stream = cStringIO.StringIO()
+ handler = logging.StreamHandler(stream)
+ handler.setFormatter(log.LegacyNovaFormatter())
+ mylog.logger.addHandler(handler)
+ yield stream
+ finally:
+ mylog.logger.removeHandler(handler)
def test_verify_checksum(self):
testdata = ('OpenStack Software delivers a massively scalable cloud '
@@ -295,74 +294,69 @@ class ImageCacheManagerTestCase(test.TestCase):
img = {'container_format': 'ami', 'id': '42'}
self.flags(checksum_base_images=True)
- mylog, handler, stream = self._intercept_log_messages()
-
- try:
- dirname = tempfile.mkdtemp()
- fname = os.path.join(dirname, 'aaa')
-
- f = open(fname, 'w')
- f.write(testdata)
- f.close()
-
- # Checksum is valid
- f = open('%s.sha1' % fname, 'w')
- csum = hashlib.sha1()
- csum.update(testdata)
- f.write(csum.hexdigest())
- f.close()
-
- image_cache_manager = imagecache.ImageCacheManager()
- res = image_cache_manager._verify_checksum(img, fname)
- self.assertTrue(res)
-
- # Checksum is invalid
- f = open('%s.sha1' % fname, 'w')
- f.write('banana')
- f.close()
-
- image_cache_manager = imagecache.ImageCacheManager()
- res = image_cache_manager._verify_checksum(img, fname)
- self.assertFalse(res)
- self.assertNotEqual(stream.getvalue().find('image verification '
- 'failed'), -1)
-
- # Checksum file missing
- os.remove('%s.sha1' % fname)
- image_cache_manager = imagecache.ImageCacheManager()
- res = image_cache_manager._verify_checksum(img, fname)
- self.assertEquals(res, None)
-
- # Checksum requests for a file with no checksum now have the
- # side effect of creating the checksum
- self.assertTrue(os.path.exists('%s.sha1' % fname))
-
- finally:
- shutil.rmtree(dirname)
- mylog.logger.removeHandler(handler)
- def _make_base_file(checksum=True):
+ with self._intercept_log_messages() as stream:
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
+
+ f = open(fname, 'w')
+ f.write(testdata)
+ f.close()
+
+ # Checksum is valid
+ f = open('%s.sha1' % fname, 'w')
+ csum = hashlib.sha1()
+ csum.update(testdata)
+ f.write(csum.hexdigest())
+ f.close()
+
+ image_cache_manager = imagecache.ImageCacheManager()
+ res = image_cache_manager._verify_checksum(img, fname)
+ self.assertTrue(res)
+
+ # Checksum is invalid
+ f = open('%s.sha1' % fname, 'w')
+ f.write('banana')
+ f.close()
+
+ image_cache_manager = imagecache.ImageCacheManager()
+ res = image_cache_manager._verify_checksum(img, fname)
+ self.assertFalse(res)
+ log = stream.getvalue()
+ self.assertNotEqual(log.find('image verification failed'), -1)
+
+ # Checksum file missing
+ os.remove('%s.sha1' % fname)
+ image_cache_manager = imagecache.ImageCacheManager()
+ res = image_cache_manager._verify_checksum(img, fname)
+ self.assertEquals(res, None)
+
+ # Checksum requests for a file with no checksum now have the
+ # side effect of creating the checksum
+ self.assertTrue(os.path.exists('%s.sha1' % fname))
+
+ @contextlib.contextmanager
+ def _make_base_file(self, checksum=True):
"""Make a base file for testing."""
- dirname = tempfile.mkdtemp()
- fname = os.path.join(dirname, 'aaa')
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
- base_file = open(fname, 'w')
- base_file.write('data')
- base_file.close()
- base_file = open(fname, 'r')
+ base_file = open(fname, 'w')
+ base_file.write('data')
+ base_file.close()
+ base_file = open(fname, 'r')
- if checksum:
- checksum_file = open('%s.sha1' % fname, 'w')
- checksum_file.write(utils.hash_file(base_file))
- checksum_file.close()
+ if checksum:
+ checksum_file = open('%s.sha1' % fname, 'w')
+ checksum_file.write(utils.hash_file(base_file))
+ checksum_file.close()
- base_file.close()
- return dirname, fname
+ base_file.close()
+ yield fname
def test_remove_base_file(self):
- dirname, fname = self._make_base_file()
- try:
+ with self._make_base_file() as fname:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager._remove_base_file(fname)
@@ -377,12 +371,8 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertFalse(os.path.exists(fname))
self.assertFalse(os.path.exists('%s.sha1' % fname))
- finally:
- shutil.rmtree(dirname)
-
def test_remove_base_file_original(self):
- dirname, fname = self._make_base_file()
- try:
+ with self._make_base_file() as fname:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.originals = [fname]
image_cache_manager._remove_base_file(fname)
@@ -405,51 +395,38 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertFalse(os.path.exists(fname))
self.assertFalse(os.path.exists('%s.sha1' % fname))
- finally:
- shutil.rmtree(dirname)
-
def test_remove_base_file_dne(self):
# This test is solely to execute the "does not exist" code path. We
# don't expect the method being tested to do anything in this case.
- dirname = tempfile.mkdtemp()
- try:
- fname = os.path.join(dirname, 'aaa')
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager._remove_base_file(fname)
- finally:
- shutil.rmtree(dirname)
-
def test_remove_base_file_oserror(self):
- dirname = tempfile.mkdtemp()
- fname = os.path.join(dirname, 'aaa')
- mylog, handler, stream = self._intercept_log_messages()
+ with self._intercept_log_messages() as stream:
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
- try:
- os.mkdir(fname)
- os.utime(fname, (-1, time.time() - 3601))
+ os.mkdir(fname)
+ os.utime(fname, (-1, time.time() - 3601))
- # This will raise an OSError because of file permissions
- image_cache_manager = imagecache.ImageCacheManager()
- image_cache_manager._remove_base_file(fname)
-
- self.assertTrue(os.path.exists(fname))
- self.assertNotEqual(stream.getvalue().find('Failed to remove'),
- -1)
+ # This will raise an OSError because of file permissions
+ image_cache_manager = imagecache.ImageCacheManager()
+ image_cache_manager._remove_base_file(fname)
- finally:
- shutil.rmtree(dirname)
- mylog.logger.removeHandler(handler)
+ self.assertTrue(os.path.exists(fname))
+ self.assertNotEqual(stream.getvalue().find('Failed to remove'),
+ -1)
def test_handle_base_image_unused(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
- dirname, fname = self._make_base_file()
- os.utime(fname, (-1, time.time() - 3601))
+ with self._make_base_file() as fname:
+ os.utime(fname, (-1, time.time() - 3601))
- try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager._handle_base_image(img, fname)
@@ -459,18 +436,14 @@ class ImageCacheManagerTestCase(test.TestCase):
[fname])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
- finally:
- shutil.rmtree(dirname)
-
def test_handle_base_image_used(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
- dirname, fname = self._make_base_file()
- os.utime(fname, (-1, time.time() - 3601))
+ with self._make_base_file() as fname:
+ os.utime(fname, (-1, time.time() - 3601))
- try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
@@ -480,18 +453,14 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
- finally:
- shutil.rmtree(dirname)
-
def test_handle_base_image_used_remotely(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
- dirname, fname = self._make_base_file()
- os.utime(fname, (-1, time.time() - 3601))
+ with self._make_base_file() as fname:
+ os.utime(fname, (-1, time.time() - 3601))
- try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.used_images = {'123': (0, 1, ['banana-42'])}
image_cache_manager._handle_base_image(img, None)
@@ -500,9 +469,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
- finally:
- shutil.rmtree(dirname)
-
def test_handle_base_image_absent(self):
"""Ensure we warn for use of a missing base image."""
@@ -510,9 +476,7 @@ class ImageCacheManagerTestCase(test.TestCase):
'id': '123',
'uuid': '1234-4567-2378'}
- mylog, handler, stream = self._intercept_log_messages()
-
- try:
+ with self._intercept_log_messages() as stream:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
image_cache_manager._handle_base_image(img, None)
@@ -523,18 +487,14 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertNotEqual(stream.getvalue().find('an absent base file'),
-1)
- finally:
- mylog.logger.removeHandler(handler)
-
def test_handle_base_image_used_missing(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
- dirname = tempfile.mkdtemp()
- fname = os.path.join(dirname, 'aaa')
+ with utils.tempdir() as tmpdir:
+ fname = os.path.join(tmpdir, 'aaa')
- try:
image_cache_manager = imagecache.ImageCacheManager()
image_cache_manager.unexplained_images = [fname]
image_cache_manager.used_images = {'123': (1, 0, ['banana-42'])}
@@ -544,17 +504,12 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.removable_base_files, [])
self.assertEquals(image_cache_manager.corrupt_base_files, [])
- finally:
- shutil.rmtree(dirname)
-
def test_handle_base_image_checksum_fails(self):
img = {'container_format': 'ami',
'id': '123',
'uuid': '1234-4567-2378'}
- dirname, fname = self._make_base_file()
-
- try:
+ with self._make_base_file() as fname:
f = open(fname, 'w')
f.write('banana')
f.close()
@@ -569,9 +524,6 @@ class ImageCacheManagerTestCase(test.TestCase):
self.assertEquals(image_cache_manager.corrupt_base_files,
[fname])
- finally:
- shutil.rmtree(dirname)
-
def test_verify_base_images(self):
self.flags(instances_path='/instance_path')
self.flags(remove_unused_base_images=True)
diff --git a/nova/tests/test_libvirt.py b/nova/tests/test_libvirt.py
index 3044cd136e..3b866ed01f 100644
--- a/nova/tests/test_libvirt.py
+++ b/nova/tests/test_libvirt.py
@@ -1050,26 +1050,25 @@ class LibvirtConnTestCase(test.TestCase):
def test_pre_block_migration_works_correctly(self):
"""Confirms pre_block_migration works correctly."""
# Replace instances_path since this testcase creates tmpfile
- tmpdir = tempfile.mkdtemp()
- self.flags(instances_path=tmpdir)
+ with utils.tempdir() as tmpdir:
+ self.flags(instances_path=tmpdir)
- # Test data
- instance_ref = db.instance_create(self.context, self.test_instance)
- dummyjson = ('[{"path": "%s/disk", "disk_size": "10737418240",'
- ' "type": "raw", "backing_file": ""}]')
+ # Test data
+ instance_ref = db.instance_create(self.context, self.test_instance)
+ dummyjson = ('[{"path": "%s/disk", "disk_size": "10737418240",'
+ ' "type": "raw", "backing_file": ""}]')
- # Preparing mocks
- # qemu-img should be mockd since test environment might not have
- # large disk space.
- self.mox.ReplayAll()
- conn = connection.LibvirtConnection(False)
- conn.pre_block_migration(self.context, instance_ref,
- dummyjson % tmpdir)
+ # Preparing mocks
+ # qemu-img should be mockd since test environment might not have
+ # large disk space.
+ self.mox.ReplayAll()
+ conn = connection.LibvirtConnection(False)
+ conn.pre_block_migration(self.context, instance_ref,
+ dummyjson % tmpdir)
- self.assertTrue(os.path.exists('%s/%s/' %
- (tmpdir, instance_ref.name)))
+ self.assertTrue(os.path.exists('%s/%s/' %
+ (tmpdir, instance_ref.name)))
- shutil.rmtree(tmpdir)
db.instance_destroy(self.context, instance_ref['id'])
@test.skip_if(missing_libvirt(), "Test requires libvirt")
@@ -1926,13 +1925,10 @@ disk size: 4.4M''', ''))
libvirt_utils.mkfs('swap', '/my/swap/block/dev')
def test_ensure_tree(self):
- tmpdir = tempfile.mkdtemp()
- try:
+ with utils.tempdir() as tmpdir:
testdir = '%s/foo/bar/baz' % (tmpdir,)
libvirt_utils.ensure_tree(testdir)
self.assertTrue(os.path.isdir(testdir))
- finally:
- shutil.rmtree(tmpdir)
def test_write_to_file(self):
dst_fd, dst_path = tempfile.mkstemp()
diff --git a/nova/utils.py b/nova/utils.py
index 6bb0dd0f26..ef4932146b 100644
--- a/nova/utils.py
+++ b/nova/utils.py
@@ -31,9 +31,11 @@ import pyclbr
import random
import re
import shlex
+import shutil
import socket
import struct
import sys
+import tempfile
import time
import types
import uuid
@@ -1543,3 +1545,15 @@ def temporary_chown(path, owner_uid=None):
finally:
if orig_uid != owner_uid:
execute('chown', orig_uid, path, run_as_root=True)
+
+
+@contextlib.contextmanager
+def tempdir(**kwargs):
+ tmpdir = tempfile.mkdtemp(**kwargs)
+ try:
+ yield tmpdir
+ finally:
+ try:
+ shutil.rmtree(tmpdir)
+ except OSError, e:
+ LOG.debug(_('Could not remove tmpdir: %s'), str(e))
diff --git a/nova/virt/libvirt/connection.py b/nova/virt/libvirt/connection.py
index c531c2cc81..9ee0985079 100644
--- a/nova/virt/libvirt/connection.py
+++ b/nova/virt/libvirt/connection.py
@@ -46,7 +46,6 @@ import multiprocessing
import os
import shutil
import sys
-import tempfile
import uuid
from eventlet import greenthread
@@ -622,23 +621,21 @@ class LibvirtConnection(driver.ComputeDriver):
disk_path = source.get('file')
# Export the snapshot to a raw image
- temp_dir = tempfile.mkdtemp()
- try:
- out_path = os.path.join(temp_dir, snapshot_name)
- libvirt_utils.extract_snapshot(disk_path, source_format,
- snapshot_name, out_path,
- image_format)
- # Upload that image to the image service
- with libvirt_utils.file_open(out_path) as image_file:
- image_service.update(context,
- image_href,
- metadata,
- image_file)
-
- finally:
- # Clean up
- shutil.rmtree(temp_dir)
- snapshot_ptr.delete(0)
+ with utils.tempdir() as tmpdir:
+ try:
+ out_path = os.path.join(tmpdir, snapshot_name)
+ libvirt_utils.extract_snapshot(disk_path, source_format,
+ snapshot_name, out_path,
+ image_format)
+ # Upload that image to the image service
+ with libvirt_utils.file_open(out_path) as image_file:
+ image_service.update(context,
+ image_href,
+ metadata,
+ image_file)
+
+ finally:
+ snapshot_ptr.delete(0)
@exception.wrap_exception()
def reboot(self, instance, network_info, reboot_type=None, xml=None):
diff --git a/nova/virt/xenapi/vm_utils.py b/nova/virt/xenapi/vm_utils.py
index 3c5264f975..83103d6f28 100644
--- a/nova/virt/xenapi/vm_utils.py
+++ b/nova/virt/xenapi/vm_utils.py
@@ -25,7 +25,6 @@ import json
import os
import pickle
import re
-import tempfile
import time
import urllib
import urlparse
@@ -1750,8 +1749,7 @@ def _mounted_processing(device, key, net, metadata):
"""Callback which runs with the image VDI attached"""
# NB: Partition 1 hardcoded
dev_path = utils.make_dev_path(device, partition=1)
- tmpdir = tempfile.mkdtemp()
- try:
+ with utils.tempdir() as tmpdir:
# Mount only Linux filesystems, to avoid disturbing NTFS images
err = _mount_filesystem(dev_path, tmpdir)
if not err:
@@ -1770,9 +1768,6 @@ def _mounted_processing(device, key, net, metadata):
else:
LOG.info(_('Failed to mount filesystem (expected for '
'non-linux instances): %s') % err)
- finally:
- # remove temporary directory
- os.rmdir(tmpdir)
def _prepare_injectables(inst, networks_info):