diff --git a/pkg/utils/untar.go b/pkg/utils/untar.go
index ed6c6cb2b..592bfa226 100644
--- a/pkg/utils/untar.go
+++ b/pkg/utils/untar.go
@@ -1,9 +1,13 @@
 package utils
 
 import (
+	"archive/tar"
 	"fmt"
 	"os"
+	"path/filepath"
+	"strings"
 
+	"github.com/klauspost/compress/zip"
 	"github.com/mholt/archiver/v3"
 )
 
@@ -54,7 +58,15 @@ func ExtractArchive(archive, dst string) error {
 		v.Tar = mytar
 	}
 
+	extractRoot, err := filepath.Abs(dst)
+	if err != nil {
+		return err
+	}
+
 	err = archiver.Walk(archive, func(f archiver.File) error {
+		if err := validateArchiveMemberPath(extractRoot, archiveMemberName(f)); err != nil {
+			return err
+		}
 		if f.FileInfo.Mode()&os.ModeSymlink != 0 {
 			return fmt.Errorf("archive contains a symlink")
 		}
@@ -67,3 +79,50 @@ func ExtractArchive(archive, dst string) error {
 
 	return un.Unarchive(archive, dst)
 }
+
+// archiveMemberName returns the member path recorded in the entry's tar or zip
+// header, falling back to f.Name() when the header is of any other type.
+func archiveMemberName(f archiver.File) string {
+	switch h := f.Header.(type) {
+	case tar.Header:
+		return h.Name
+	case *tar.Header:
+		return h.Name
+	case zip.FileHeader:
+		return h.Name
+	case *zip.FileHeader:
+		return h.Name
+	default:
+		return f.Name()
+	}
+}
+
+// validateArchiveMemberPath returns an error when the archive member name is
+// empty, absolute, or would resolve outside root once joined to it.
+// Backslashes are treated as path separators on every platform, so
+// Windows-style traversal sequences are rejected too. Only the member's own
+// name is inspected.
+func validateArchiveMemberPath(root, name string) error {
+	if name == "" {
+		return fmt.Errorf("archive contains an empty path")
+	}
+
+	normalizedName := filepath.FromSlash(strings.ReplaceAll(name, "\\", "/"))
+	cleanedName := filepath.Clean(normalizedName)
+	if filepath.IsAbs(cleanedName) || cleanedName == ".." || strings.HasPrefix(cleanedName, ".."+string(os.PathSeparator)) {
+		// %q keeps control characters in an untrusted name from reaching logs verbatim.
+		return fmt.Errorf("archive contains an unsafe path: %q", name)
+	}
+
+	// Re-check the joined result so the guarantee does not rest on the lexical test alone.
+	targetPath := filepath.Join(root, cleanedName)
+	relativePath, err := filepath.Rel(root, targetPath)
+	if err != nil {
+		return fmt.Errorf("resolving archive path %q: %w", name, err)
+	}
+	if relativePath == ".." || strings.HasPrefix(relativePath, ".."+string(os.PathSeparator)) || filepath.IsAbs(relativePath) {
+		return fmt.Errorf("archive contains an unsafe path: %q", name)
+	}
+
+	return nil
+}
diff --git a/pkg/utils/untar_test.go b/pkg/utils/untar_test.go
new file mode 100644
index 000000000..e82b3611f
--- /dev/null
+++ b/pkg/utils/untar_test.go
@@ -0,0 +1,132 @@
+package utils_test
+
+import (
+	"archive/tar"
+	"archive/zip"
+	"os"
+	"path/filepath"
+
+	. "github.com/mudler/LocalAI/pkg/utils"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("utils/archive tests", func() {
+	It("extracts regular nested zip members", func() {
+		tmpDir := GinkgoT().TempDir()
+		archivePath := filepath.Join(tmpDir, "model.zip")
+		extractPath := filepath.Join(tmpDir, "models")
+
+		Expect(writeZipArchive(archivePath, map[string]string{
+			"nested/model.yaml": "name: test",
+		})).To(Succeed())
+
+		Expect(ExtractArchive(archivePath, extractPath)).To(Succeed())
+
+		extracted, err := os.ReadFile(filepath.Join(extractPath, "nested", "model.yaml"))
+		Expect(err).ToNot(HaveOccurred())
+		Expect(string(extracted)).To(Equal("name: test"))
+	})
+
+	It("rejects zip members that escape the destination", func() {
+		tmpDir := GinkgoT().TempDir()
+		archivePath := filepath.Join(tmpDir, "model.zip")
+		extractPath := filepath.Join(tmpDir, "models")
+
+		Expect(writeZipArchive(archivePath, map[string]string{
+			"../escaped.txt": "escaped",
+		})).To(Succeed())
+
+		err := ExtractArchive(archivePath, extractPath)
+
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("unsafe path"))
+		Expect(filepath.Join(tmpDir, "escaped.txt")).ToNot(BeAnExistingFile())
+	})
+
+	It("rejects tar members that escape the destination", func() {
+		tmpDir := GinkgoT().TempDir()
+		archivePath := filepath.Join(tmpDir, "model.tar")
+		extractPath := filepath.Join(tmpDir, "models")
+
+		Expect(writeTarArchive(archivePath, map[string]string{
+			"../escaped.txt": "escaped",
+		})).To(Succeed())
+
+		err := ExtractArchive(archivePath, extractPath)
+
+		Expect(err).To(HaveOccurred())
+		Expect(err.Error()).To(ContainSubstring("unsafe path"))
+		Expect(filepath.Join(tmpDir, "escaped.txt")).ToNot(BeAnExistingFile())
+	})
+})
+
+// writeZipArchive writes a zip archive to path with one member per entry in
+// files, keyed by member name with the file contents as the value.
+func writeZipArchive(path string, files map[string]string) (err error) {
+	out, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		if closeErr := out.Close(); err == nil {
+			err = closeErr
+		}
+	}()
+
+	writer := zip.NewWriter(out)
+	defer func() {
+		if closeErr := writer.Close(); err == nil {
+			err = closeErr
+		}
+	}()
+
+	for name, contents := range files {
+		fileWriter, err := writer.Create(name)
+		if err != nil {
+			return err
+		}
+		if _, err := fileWriter.Write([]byte(contents)); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// writeTarArchive writes a tar archive to path with one member per entry in
+// files, keyed by member name with the file contents as the value.
+func writeTarArchive(path string, files map[string]string) (err error) {
+	out, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		if closeErr := out.Close(); err == nil {
+			err = closeErr
+		}
+	}()
+
+	writer := tar.NewWriter(out)
+	defer func() {
+		if closeErr := writer.Close(); err == nil {
+			err = closeErr
+		}
+	}()
+
+	for name, contents := range files {
+		data := []byte(contents)
+		if err := writer.WriteHeader(&tar.Header{
+			Name: name,
+			Mode: 0o600,
+			Size: int64(len(data)),
+		}); err != nil {
+			return err
+		}
+		if _, err := writer.Write(data); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}