summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cmd/gofmt/gofmt.go125
-rw-r--r--src/cmd/gofmt/gofmt_unix_test.go67
2 files changed, 166 insertions, 26 deletions
diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go
index bb22aea031..f4fb6bff84 100644
--- a/src/cmd/gofmt/gofmt.go
+++ b/src/cmd/gofmt/gofmt.go
@@ -17,10 +17,12 @@ import (
"internal/diff"
"io"
"io/fs"
+ "math/rand"
"os"
"path/filepath"
"runtime"
"runtime/pprof"
+ "strconv"
"strings"
"golang.org/x/sync/semaphore"
@@ -269,21 +271,9 @@ func processFile(filename string, info fs.FileInfo, in io.Reader, r *reporter) e
if info == nil {
panic("-w should not have been allowed with stdin")
}
- // make a temporary backup before overwriting original
+
perm := info.Mode().Perm()
- bakname, err := backupFile(filename+".", src, perm)
- if err != nil {
- return err
- }
- fdSem <- true
- err = os.WriteFile(filename, res, perm)
- <-fdSem
- if err != nil {
- os.Rename(bakname, filename)
- return err
- }
- err = os.Remove(bakname)
- if err != nil {
+ if err := writeFile(filename, src, res, perm, info.Size()); err != nil {
return err
}
}
@@ -467,28 +457,111 @@ func fileWeight(path string, info fs.FileInfo) int64 {
return info.Size()
}
-// backupFile writes data to a new file named filename<number> with permissions perm,
-// with <number randomly chosen such that the file name is unique. backupFile returns
-// the chosen file name.
-func backupFile(filename string, data []byte, perm fs.FileMode) (string, error) {
+// writeFile updates a file with the new formatted data.
+func writeFile(filename string, orig, formatted []byte, perm fs.FileMode, size int64) error {
+ // Make a temporary backup file before rewriting the original file.
+ bakname, err := backupFile(filename, orig, perm)
+ if err != nil {
+ return err
+ }
+
fdSem <- true
defer func() { <-fdSem }()
- // create backup file
- f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename))
+ fout, err := os.OpenFile(filename, os.O_WRONLY, perm)
if err != nil {
- return "", err
+ // We couldn't even open the file, so it should
+ // not have changed.
+ os.Remove(bakname)
+ return err
}
- bakname := f.Name()
- err = f.Chmod(perm)
+ defer fout.Close() // for error paths
+
+ restoreFail := func(err error) {
+ fmt.Fprintf(os.Stderr, "gofmt: %s: error restoring file to original: %v; backup in %s\n", filename, err, bakname)
+ }
+
+ n, err := fout.Write(formatted)
+ if err == nil && int64(n) < size {
+ err = fout.Truncate(int64(n))
+ }
+
if err != nil {
- f.Close()
+ // Rewriting the file failed.
+
+ if n == 0 {
+ // Original file unchanged.
+ os.Remove(bakname)
+ return err
+ }
+
+ // Try to restore the original contents.
+
+ no, erro := fout.WriteAt(orig, 0)
+ if erro != nil {
+ // That failed too.
+ restoreFail(erro)
+ return err
+ }
+
+ if no < n {
+ // Original file is shorter. Truncate.
+ if erro = fout.Truncate(int64(no)); erro != nil {
+ restoreFail(erro)
+ return err
+ }
+ }
+
+ if erro := fout.Close(); erro != nil {
+ restoreFail(erro)
+ return err
+ }
+
+ // Original contents restored.
os.Remove(bakname)
- return bakname, err
+ return err
+ }
+
+ if err := fout.Close(); err != nil {
+ restoreFail(err)
+ return err
+ }
+
+ // File updated.
+ os.Remove(bakname)
+ return nil
+}
+
+// backupFile writes data to a new file named filename<number> with permissions perm,
+// with <number> randomly chosen such that the file name is unique. backupFile returns
+// the chosen file name.
+func backupFile(filename string, data []byte, perm fs.FileMode) (string, error) {
+ fdSem <- true
+ defer func() { <-fdSem }()
+
+ nextRandom := func() string {
+ return strconv.Itoa(rand.Int())
+ }
+
+ dir, base := filepath.Split(filename)
+ var (
+ bakname string
+ f *os.File
+ )
+ for {
+ bakname = filepath.Join(dir, base+"."+nextRandom())
+ var err error
+ f, err = os.OpenFile(bakname, os.O_RDWR|os.O_CREATE|os.O_EXCL, perm)
+ if err == nil {
+ break
+ }
+ if err != nil && !os.IsExist(err) {
+ return "", err
+ }
}
// write data to backup file
- _, err = f.Write(data)
+ _, err := f.Write(data)
if err1 := f.Close(); err == nil {
err = err1
}
diff --git a/src/cmd/gofmt/gofmt_unix_test.go b/src/cmd/gofmt/gofmt_unix_test.go
new file mode 100644
index 0000000000..45b9234312
--- /dev/null
+++ b/src/cmd/gofmt/gofmt_unix_test.go
@@ -0,0 +1,67 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package main
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestPermissions(t *testing.T) {
+ if os.Getuid() == 0 {
+ t.Skip("skipping permission test when running as root")
+ }
+
+ dir := t.TempDir()
+ fn := filepath.Join(dir, "perm.go")
+
+ // Create a file that needs formatting without write permission.
+ if err := os.WriteFile(filepath.Join(fn), []byte(" package main"), 0o400); err != nil {
+ t.Fatal(err)
+ }
+
+ // Set mtime of the file in the past.
+ past := time.Now().Add(-time.Hour)
+ if err := os.Chtimes(fn, past, past); err != nil {
+ t.Fatal(err)
+ }
+
+ info, err := os.Stat(fn)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ defer func() { *write = false }()
+ *write = true
+
+ initParserMode()
+ initRewrite()
+
+ const maxWeight = 2 << 20
+ var buf, errBuf strings.Builder
+ s := newSequencer(maxWeight, &buf, &errBuf)
+ s.Add(fileWeight(fn, info), func(r *reporter) error {
+ return processFile(fn, info, nil, r)
+ })
+ if errBuf.Len() > 0 {
+ t.Log(errBuf)
+ }
+ if s.GetExitCode() == 0 {
+ t.Fatal("rewrite of read-only file succeeded unexpectedly")
+ }
+
+ info, err = os.Stat(fn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !info.ModTime().Equal(past) {
+ t.Errorf("after rewrite mod time is %v, want %v", info.ModTime(), past)
+ }
+}