summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohannes Gijsbers <jlg@dds.nl>2004-08-14 13:30:02 +0000
committerJohannes Gijsbers <jlg@dds.nl>2004-08-14 13:30:02 +0000
commit46f1459860280d9f57d92f1fb389f657e0b2724b (patch)
treef94de8d643c7a7171d664e73ca75999251ee8223
parent404b06814c1e6d527b6e659a907ca1b15a6eaa28 (diff)
downloadcpython-git-46f1459860280d9f57d92f1fb389f657e0b2724b.tar.gz
Raise an exception when src and dst refer to the same file via a hard link or a
symbolic link (bug #851123 / patch #854853, thanks Gregory Ball).
-rw-r--r--Lib/shutil.py16
-rw-r--r--Lib/test/test_shutil.py21
2 files changed, 32 insertions, 5 deletions
diff --git a/Lib/shutil.py b/Lib/shutil.py
index fde8c90fe9..d361fa2b5d 100644
--- a/Lib/shutil.py
+++ b/Lib/shutil.py
@@ -24,16 +24,22 @@ def copyfileobj(fsrc, fdst, length=16*1024):
break
fdst.write(buf)
+def _samefile(src, dst):
+ # Macintosh, Unix.
+ if hasattr(os.path,'samefile'):
+ return os.path.samefile(src, dst)
+
+ # All other platforms: check for same pathname.
+ return (os.path.normcase(os.path.abspath(src)) ==
+ os.path.normcase(os.path.abspath(dst)))
def copyfile(src, dst):
"""Copy data from src to dst"""
+ if _samefile(src, dst):
+ raise Error, "`%s` and `%s` are the same file" % (src, dst)
+
fsrc = None
fdst = None
- # check for same pathname; all platforms
- _src = os.path.normcase(os.path.abspath(src))
- _dst = os.path.normcase(os.path.abspath(dst))
- if _src == _dst:
- return
try:
fsrc = open(src, 'rb')
fdst = open(dst, 'wb')
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index bcae72f1da..083dbda706 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -6,6 +6,7 @@ import tempfile
import os
import os.path
from test import test_support
+from test.test_support import TESTFN
class TestShutil(unittest.TestCase):
def test_rmtree_errors(self):
@@ -26,6 +27,26 @@ class TestShutil(unittest.TestCase):
except:
pass
+ if hasattr(os, "symlink"):
+ def test_dont_copy_file_onto_link_to_itself(self):
+ # bug 851123.
+ os.mkdir(TESTFN)
+ src = os.path.join(TESTFN,'cheese')
+ dst = os.path.join(TESTFN,'shop')
+ try:
+ f = open(src,'w')
+ f.write('cheddar')
+ f.close()
+ for funcname in 'link','symlink':
+ getattr(os, funcname)(src, dst)
+ self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
+ self.assertEqual(open(src,'r').read(), 'cheddar')
+ os.remove(dst)
+ finally:
+ try:
+ shutil.rmtree(TESTFN)
+ except OSError:
+ pass
def test_main():