Validate archive member paths before extraction (#9820)

Signed-off-by: massy-o <telitos000@gmail.com>
This commit is contained in:
massy_o
2026-05-15 18:12:13 +09:00
committed by GitHub
parent 4abf5befbb
commit 8c785dbe4a
2 changed files with 178 additions and 0 deletions

View File

@@ -1,9 +1,13 @@
package utils
import (
"archive/tar"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/klauspost/compress/zip"
"github.com/mholt/archiver/v3"
)
@@ -54,7 +58,15 @@ func ExtractArchive(archive, dst string) error {
v.Tar = mytar
}
extractRoot, err := filepath.Abs(dst)
if err != nil {
return err
}
err = archiver.Walk(archive, func(f archiver.File) error {
if err := validateArchiveMemberPath(extractRoot, archiveMemberName(f)); err != nil {
return err
}
if f.FileInfo.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("archive contains a symlink")
}
@@ -67,3 +79,41 @@ func ExtractArchive(archive, dst string) error {
return un.Unarchive(archive, dst)
}
// archiveMemberName returns the path that the archive header records for f.
//
// When f.Header holds a tar or zip header (either by value or by pointer) the
// header's Name field is returned. For any other header type the function
// falls back to f.Name().
// NOTE(review): f.Name() presumably yields only the base name of the entry —
// confirm against the archiver.File documentation.
func archiveMemberName(f archiver.File) string {
	if hdr, ok := f.Header.(tar.Header); ok {
		return hdr.Name
	}
	if hdr, ok := f.Header.(*tar.Header); ok {
		return hdr.Name
	}
	if hdr, ok := f.Header.(zip.FileHeader); ok {
		return hdr.Name
	}
	if hdr, ok := f.Header.(*zip.FileHeader); ok {
		return hdr.Name
	}
	return f.Name()
}
// validateArchiveMemberPath checks, lexically, that the archive member called
// name stays inside root when joined onto it.
//
// Backslashes in name are treated as path separators before the path is
// cleaned. An error is returned when name is empty, when the cleaned path is
// absolute, or when it (or its path relative to root) is ".." or starts with
// a ".." component. A nil result means none of those conditions held; no
// filesystem access is performed.
func validateArchiveMemberPath(root, name string) error {
	if name == "" {
		return fmt.Errorf("archive contains an empty path")
	}

	// escapes reports whether p is absolute or climbs above its starting
	// directory.
	parentPrefix := ".." + string(os.PathSeparator)
	escapes := func(p string) bool {
		return p == ".." || strings.HasPrefix(p, parentPrefix) || filepath.IsAbs(p)
	}

	// Normalise both separator styles to the host separator, then collapse
	// any "." and ".." components.
	slashed := strings.ReplaceAll(name, "\\", "/")
	member := filepath.Clean(filepath.FromSlash(slashed))
	if escapes(member) {
		return fmt.Errorf("archive contains an unsafe path: %s", name)
	}

	// Second line of defence: re-derive the path relative to root from the
	// joined target and apply the same test to it.
	rel, err := filepath.Rel(root, filepath.Join(root, member))
	if err != nil {
		return err
	}
	if escapes(rel) {
		return fmt.Errorf("archive contains an unsafe path: %s", name)
	}
	return nil
}

128
pkg/utils/untar_test.go Normal file
View File

@@ -0,0 +1,128 @@
package utils_test
import (
"archive/tar"
"archive/zip"
"os"
"path/filepath"
. "github.com/mudler/LocalAI/pkg/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Specs for ExtractArchive: a well-formed nested member must be extracted,
// while zip and tar members whose names climb out of the destination must be
// rejected without creating the escaped file.
var _ = Describe("utils/archive tests", func() {
	It("extracts regular nested zip members", func() {
		workDir := GinkgoT().TempDir()
		src := filepath.Join(workDir, "model.zip")
		dst := filepath.Join(workDir, "models")

		members := map[string]string{
			"nested/model.yaml": "name: test",
		}
		Expect(writeZipArchive(src, members)).To(Succeed())
		Expect(ExtractArchive(src, dst)).To(Succeed())

		// The member must land at the same nested location under dst.
		got, err := os.ReadFile(filepath.Join(dst, "nested", "model.yaml"))
		Expect(err).ToNot(HaveOccurred())
		Expect(string(got)).To(Equal("name: test"))
	})

	It("rejects zip members that escape the destination", func() {
		workDir := GinkgoT().TempDir()
		src := filepath.Join(workDir, "model.zip")
		dst := filepath.Join(workDir, "models")

		members := map[string]string{
			"../escaped.txt": "escaped",
		}
		Expect(writeZipArchive(src, members)).To(Succeed())

		err := ExtractArchive(src, dst)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("unsafe path"))
		// Nothing may have been written next to the destination directory.
		Expect(filepath.Join(workDir, "escaped.txt")).ToNot(BeAnExistingFile())
	})

	It("rejects tar members that escape the destination", func() {
		workDir := GinkgoT().TempDir()
		src := filepath.Join(workDir, "model.tar")
		dst := filepath.Join(workDir, "models")

		members := map[string]string{
			"../escaped.txt": "escaped",
		}
		Expect(writeTarArchive(src, members)).To(Succeed())

		err := ExtractArchive(src, dst)
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("unsafe path"))
		// Nothing may have been written next to the destination directory.
		Expect(filepath.Join(workDir, "escaped.txt")).ToNot(BeAnExistingFile())
	})
})
// writeZipArchive creates a zip file at path holding one entry per element of
// files, where each key is the entry name and each value its contents.
//
// The first error hit while creating the file or writing an entry is
// returned. If everything before that succeeded, an error from closing the
// zip writer or the underlying file is returned instead.
func writeZipArchive(path string, files map[string]string) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	// Deferred calls run in reverse order: the zip writer is finalised first,
	// then the file is closed. A close error is reported only when no
	// earlier error is pending.
	defer func() {
		closeErr := file.Close()
		if err == nil {
			err = closeErr
		}
	}()

	zw := zip.NewWriter(file)
	defer func() {
		closeErr := zw.Close()
		if err == nil {
			err = closeErr
		}
	}()

	for entryName, body := range files {
		entry, createErr := zw.Create(entryName)
		if createErr != nil {
			return createErr
		}
		if _, writeErr := entry.Write([]byte(body)); writeErr != nil {
			return writeErr
		}
	}
	return nil
}
// writeTarArchive creates a tar file at path holding one entry per element of
// files, where each key is the entry name and each value its contents. Every
// entry is written with mode 0o600 and a size equal to its content length.
//
// The first error hit while creating the file or writing an entry is
// returned. If everything before that succeeded, an error from closing the
// tar writer or the underlying file is returned instead.
func writeTarArchive(path string, files map[string]string) (err error) {
	file, err := os.Create(path)
	if err != nil {
		return err
	}
	// Deferred calls run in reverse order: the tar writer is finalised first,
	// then the file is closed. A close error is reported only when no
	// earlier error is pending.
	defer func() {
		closeErr := file.Close()
		if err == nil {
			err = closeErr
		}
	}()

	tw := tar.NewWriter(file)
	defer func() {
		closeErr := tw.Close()
		if err == nil {
			err = closeErr
		}
	}()

	for entryName, body := range files {
		payload := []byte(body)
		header := &tar.Header{
			Name: entryName,
			Mode: 0o600,
			Size: int64(len(payload)),
		}
		if headerErr := tw.WriteHeader(header); headerErr != nil {
			return headerErr
		}
		if _, writeErr := tw.Write(payload); writeErr != nil {
			return writeErr
		}
	}
	return nil
}