summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Hupp <adam@hupp.org>2014-01-04 12:48:39 -0800
committerAdam Hupp <adam@hupp.org>2014-01-04 13:01:41 -0800
commita5ebf0d558c3a851c0884d30a0051836f7db0c2a (patch)
treeeeb8c036fd53d1430ab90a6547f0a78856ac9bfd
parenta0f2249dad53b0c6e7997560d8d333830c578b96 (diff)
downloadpython-magic-a5ebf0d558c3a851c0884d30a0051836f7db0c2a.tar.gz
Properly handle unicode filenames given in both byte-str and unicode
values, across both python 3 and 2.
-rw-r--r--magic.py13
-rw-r--r--test.py6
2 files changed, 17 insertions, 2 deletions
diff --git a/magic.py b/magic.py
index 5025842..e7336c3 100644
--- a/magic.py
+++ b/magic.py
@@ -192,7 +192,18 @@ def errorcheck_negative_one(result, func, args):
def coerce_filename(filename):
if filename is None:
return None
- return filename.encode(sys.getfilesystemencoding())
+
+ # ctypes will implicitly convert unicode strings to bytes with
+ # .encode('ascii'). A more useful default here is
+ # getfilesystemencoding(). We need to leave byte-str unchanged.
+ is_unicode = (sys.version_info.major <= 2 and
+ isinstance(filename, unicode)) or \
+ (sys.version_info.major >= 3 and
+ isinstance(filename, str))
+ if is_unicode:
+ return filename.encode(sys.getfilesystemencoding())
+ else:
+ return filename
magic_open = libmagic.magic_open
magic_open.restype = magic_t
diff --git a/test.py b/test.py
index 6412045..3d922cd 100644
--- a/test.py
+++ b/test.py
@@ -8,7 +8,10 @@ class MagicTest(unittest.TestCase):
def assert_values(self, m, expected_values):
for filename, expected_value in expected_values.items():
- filename = os.path.join(self.TESTDATA_DIR, filename)
+ try:
+ filename = os.path.join(self.TESTDATA_DIR, filename)
+ except TypeError:
+ filename = os.path.join(self.TESTDATA_DIR.encode('utf-8'), filename)
value = m.from_buffer(open(filename, 'rb').read())
expected_value_bytes = expected_value.encode('utf-8')
@@ -25,6 +28,7 @@ class MagicTest(unittest.TestCase):
'test.gz': 'application/x-gzip',
'text.txt': 'text/plain',
b'\xce\xbb'.decode('utf-8'): 'text/plain',
+ b'\xce\xbb': 'text/plain',
})
def test_descriptions(self):