summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oslotest/base.py17
-rw-r--r--tests/unit/test_base.py53
2 files changed, 68 insertions, 2 deletions
diff --git a/oslotest/base.py b/oslotest/base.py
index 213b28c..5b85b16 100644
--- a/oslotest/base.py
+++ b/oslotest/base.py
@@ -20,6 +20,7 @@ import os
import tempfile
import fixtures
+import six
from six.moves import mock
import testtools
@@ -121,17 +122,29 @@ class BaseTestCase(testtools.TestCase):
else:
logging.basicConfig(format=_LOG_FORMAT, level=level)
- def create_tempfiles(self, files, ext='.conf'):
+ def create_tempfiles(self, files, ext='.conf', default_encoding='utf-8'):
"""Safely create temporary files.
:param files: Sequence of tuples containing (filename, file_contents).
:type files: list of tuple
:param ext: File name extension for the temporary file.
:type ext: str
+ :param default_encoding: Default file content encoding when it is
+ not provided, used to decode the tempfile
+ contents from a text string into a binary
+ string.
+ :type default_encoding: str
:return: A list of str with the names of the files created.
"""
tempfiles = []
- for (basename, contents) in files:
+ for f in files:
+ if len(f) == 3:
+ basename, contents, encoding = f
+ else:
+ basename, contents = f
+ encoding = default_encoding
+ if isinstance(contents, six.text_type):
+ contents = contents.encode(encoding)
if not os.path.isabs(basename):
(fd, path) = tempfile.mkstemp(prefix=basename, suffix=ext)
else:
diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py
index 0816041..3ba0fa7 100644
--- a/tests/unit/test_base.py
+++ b/tests/unit/test_base.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
# Copyright 2014 Deutsche Telekom AG
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -13,8 +15,10 @@
# under the License.
import logging
+import os
import unittest
+import six
from six.moves import mock
import testtools
@@ -121,3 +125,52 @@ class TestManualMock(base.BaseTestCase):
patcher = mock.patch('os.environ.get')
patcher.start()
self.addCleanup(patcher.stop)
+
+
+class TestTempFiles(base.BaseTestCase):
+ def test_create_unicode_files(self):
+ files = [["no_approve", u'ಠ_ಠ']]
+ temps = self.create_tempfiles(files)
+ self.assertEqual(1, len(temps))
+ with open(temps[0], 'rb') as f:
+ contents = f.read()
+ self.assertEqual(u'ಠ_ಠ', six.text_type(contents, encoding='utf-8'))
+
+ def test_create_unicode_files_encoding(self):
+ files = [["embarrassed", u'⊙﹏⊙', 'utf-8']]
+ temps = self.create_tempfiles(files)
+ self.assertEqual(1, len(temps))
+ with open(temps[0], 'rb') as f:
+ contents = f.read()
+ self.assertEqual(u'⊙﹏⊙', six.text_type(contents, encoding='utf-8'))
+
+ def test_create_unicode_files_multi_encoding(self):
+ files = [
+ ["embarrassed", u'⊙﹏⊙', 'utf-8'],
+ ['abc', 'abc', 'ascii'],
+ ]
+ temps = self.create_tempfiles(files)
+ self.assertEqual(2, len(temps))
+ for i, (basename, raw_contents, raw_encoding) in enumerate(files):
+ with open(temps[i], 'rb') as f:
+ contents = f.read()
+ if not isinstance(raw_contents, six.text_type):
+ raw_contents = six.text_type(raw_contents,
+ encoding=raw_encoding)
+ self.assertEqual(raw_contents,
+ six.text_type(contents, encoding=raw_encoding))
+
+ def test_create_bad_encoding(self):
+ files = [["hrm", u'ಠ~ಠ', 'ascii']]
+ self.assertRaises(UnicodeError, self.create_tempfiles, files)
+
+ def test_prefix(self):
+ files = [["testing", '']]
+ temps = self.create_tempfiles(files)
+ self.assertEqual(1, len(temps))
+ basename = os.path.basename(temps[0])
+ self.assertTrue(basename.startswith('testing'))
+
+ def test_wrong_length(self):
+ files = [["testing"]]
+ self.assertRaises(ValueError, self.create_tempfiles, files)