diff --git a/file/file.go b/file/file.go index 7a50c03..56f26f8 100644 --- a/file/file.go +++ b/file/file.go @@ -54,3 +54,21 @@ func ExpandFileList(fileList []string) ([]string, error) { return result, nil } + +// WriteTmpFile creates a temporary file with the given name and content, and returns the path to the created file. +func WriteTmpFile(name, content string) (string, error) { + tmpfile, err := os.CreateTemp("", name) + if err != nil { + return "", err + } + + if _, err := tmpfile.Write([]byte(content)); err != nil { + return "", err + } + + if err := tmpfile.Close(); err != nil { + return "", err + } + + return tmpfile.Name(), nil +} diff --git a/file/file_test.go b/file/file_test.go new file mode 100644 index 0000000..618a623 --- /dev/null +++ b/file/file_test.go @@ -0,0 +1,64 @@ +package file + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +const helloWorld = "Hello, World!" + +func TestWriteTmpFile(t *testing.T) { + tests := []struct { + name string + fileName string + content string + wantErr bool + }{ + { + name: "write to temp file", + fileName: "test.txt", + content: helloWorld, + wantErr: false, + }, + { + name: "empty file name", + fileName: "", + content: helloWorld, + wantErr: false, + }, + { + name: "empty file content", + fileName: "test.txt", + content: "", + wantErr: false, + }, + { + name: "create temp file error", + fileName: filepath.Join(os.TempDir(), "non-existent", "test.txt"), + content: helloWorld, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpFile, err := WriteTmpFile(tt.fileName, tt.content) + if tt.wantErr { + assert.Error(t, err) + + return + } + + assert.NoError(t, err) + + defer os.Remove(tmpFile) + + data, err := os.ReadFile(tmpFile) + assert.NoError(t, err) + assert.Equal(t, tt.content, string(data)) + }) + } +} diff --git a/util/user.go b/util/user.go new file mode 100644 index 0000000..e1af45f --- /dev/null +++ b/util/user.go @@ -0,0 +1,15 @@ +package util + +import "os/user" + +// GetUserHomeDir returns the home directory path for the current user. +// If the current user cannot be determined, it returns the default "/root" path. +func GetUserHomeDir() string { + home := "/root" + + if currentUser, err := user.Current(); err == nil { + home = currentUser.HomeDir + } + + return home +}