Files
LocalAI/pkg/downloader/uri.go
Richard Palethorpe 12d1f3a697 security(http): refuse redirects on outbound clients via hardened pkg/httpclient (#10087)
LocalAI's outbound HTTP clients used Go's default redirect policy, which
follows up to 10 redirects. On a cross-host redirect Go forwards custom
request headers — including credential headers such as Anthropic's
x-api-key — to the redirect target (Go strips Authorization, Cookie and
WWW-Authenticate cross-host, but NOT arbitrary custom headers). An
attacker able to elicit a redirect from an upstream (a hijacked or
spoofed upstream, DNS trickery, or a malicious upstream_url) then
harvests the operator's provider API key.

This was first reported against the cloud-proxy / MITM PII path
(GHSA-3mj3-57v2-4636); the same class affects every other outbound
client. Rather than patch each call site, add pkg/httpclient as the one
sanctioned constructor for outbound HTTP and route everything through it.

pkg/httpclient:
  - New(...)             refuses redirects, TLS 1.2 floor, no body
                         deadline (streaming/SSE safe)
  - NewWithTimeout(d)    simple request/response calls
  - WithFollowRedirects  opt-in following that still strips credential
                         headers on any cross-host hop; different
                         scheme/host/port == different origin, guarding
                         the curl CVE-2022-27774 port-confusion class
  - WithTransport(rt)    keep a custom transport (IP-pin, HTTP/2, a
                         credential-injecting RoundTripper)
  - HardenedTransport()  base transport with the TLS floor + bounded setup
  - Harden(c)            apply the policy to a library-supplied *http.Client
  - NoRedirect           the CheckRedirect policy; wraps ErrRedirectBlocked

Lint: a forbidigo rule flags http.DefaultClient and http.Get/Post/
PostForm/Head, pointing at pkg/httpclient (.golangci.yml,
.agents/coding-style.md). forbidigo cannot match the &http.Client{}
composite literal without also flagging legitimate *http.Client type
references, so that form is enforced by review.

Migrates every non-test outbound call site across core/, pkg/, cmd/, and
the Go backend (backend/go/cloud-proxy). Credential-bearing and
internal-RPC clients refuse redirects; download / CDN / registry clients
use WithFollowRedirects so they keep working while stripping secrets
cross-host. The only credential-bearing client that follows redirects is
the gated-download path (pkg/downloader/uri.go), which strips the token
on the cross-host hop to the CDN. Hardening this closes, in passing:
  - MCP remote-server bearer token leaking via a redirect (the
    RoundTripper re-injected Authorization on every hop)
  - agent multimedia/webhook clients leaking user-supplied auth headers
  - cors_proxy following redirects, bypassing its SSRF IP-pin
  - downloader's authorized read path leaking the token cross-host

Fixes: GHSA-3mj3-57v2-4636 (cloud-proxy leaks operator provider API key
(x-api-key) to attacker host on cross-host redirect)
Reported-by: tonghuaroot
Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-30 12:04:10 +02:00

724 lines
23 KiB
Go

package downloader
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/google/go-containerregistry/pkg/v1/tarball"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/pkg/httpclient"
"github.com/mudler/LocalAI/pkg/oci"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/LocalAI/pkg/xio"
)
// URI prefixes recognised by the downloader.
const (
	// Local filesystem paths.
	LocalPrefix = "file://"

	// Plain web locations.
	HTTPPrefix  = "http://"
	HTTPSPrefix = "https://"

	// HuggingFace shorthands; ResolveURL maps them onto HF_ENDPOINT.
	HuggingFacePrefix  = "huggingface://"
	HuggingFacePrefix1 = "hf://"
	HuggingFacePrefix2 = "hf.co/"

	// GitHub shorthands; ResolveURL maps them onto raw.githubusercontent.com.
	GithubURI  = "github:"
	GithubURI2 = "github://"

	// Container-image style sources.
	OCIPrefix     = "oci://"
	OCIFilePrefix = "ocifile://"
	OllamaPrefix  = "ollama://"
)

// URI is a resource location as accepted by the downloader: a plain URL, a
// local path, or one of the shorthand schemes declared above.
type URI string
// ImageVerifier checks the integrity of an OCI image before it is unpacked,
// for example by validating a cosign signature against a sigstore policy.
//
// For registry images the downloader calls VerifyImage after resolving the
// image and before extracting any layer, passing a reference pinned to the
// resolved digest; a verification error aborts the download before anything
// is extracted.
//
// pkg/oci/cosignverify.Verifier satisfies this interface.
type ImageVerifier interface {
	VerifyImage(ctx context.Context, imageRef string) error
}

// downloadOptions gathers the optional settings of one download call.
type downloadOptions struct {
	// verifier, when non-nil, is run against registry OCI images.
	verifier ImageVerifier
}

// DownloadOption is a functional option for DownloadFile and
// DownloadFileWithContext. Taking options as trailing variadic arguments
// keeps existing call sites, which pass none, compiling unchanged.
type DownloadOption func(*downloadOptions)

// WithImageVerifier makes registry OCI downloads run v before extraction.
// It has no effect on ollama://, ocifile://, HTTP(S) or local downloads,
// which rely on SHA-256 checks instead.
func WithImageVerifier(v ImageVerifier) DownloadOption {
	return func(o *downloadOptions) {
		o.verifier = v
	}
}

// applyDownloadOptions applies opts, in order, to a zero-valued
// downloadOptions and returns the result.
func applyDownloadOptions(opts []DownloadOption) downloadOptions {
	resolved := downloadOptions{}
	for i := range opts {
		opts[i](&resolved)
	}
	return resolved
}
// pinnedImageRef returns ref re-pointed at digest, i.e. "repo@<digest>".
//
// Any existing "@..." suffix is dropped first: the cut is made at the last
// '@' unconditionally. A ":tag" is then dropped only if its colon lies after
// the last '/', so a registry port such as "localhost:5000/foo" survives.
// Pinning lets a follow-up client (e.g. a signature verifier) see exactly the
// manifest that was already fetched instead of re-resolving a mutable tag.
func pinnedImageRef(ref, digest string) string {
	base := ref
	if at := strings.LastIndexByte(base, '@'); at >= 0 {
		base = base[:at]
	}
	lastSlash := strings.LastIndexByte(base, '/')
	lastColon := strings.LastIndexByte(base, ':')
	if lastColon > lastSlash {
		base = base[:lastColon]
	}
	return base + "@" + digest
}
// HF_ENDPOINT is the base URL used for HuggingFace downloads. It is computed
// once, at package initialisation, by loadConfig from the HF_MIRROR, HF and
// HF_ENDPOINT environment variables.
var HF_ENDPOINT = loadConfig()

// loadConfig returns the HuggingFace base URL, chosen in this order:
//
//  1. HF_MIRROR, or HF when HF_MIRROR is unset or blank. The mirror may be a
//     full URL (https://hf-mirror.com) or a bare host (hf-mirror.com), in which
//     case https:// is prepended.
//  2. HF_ENDPOINT, used as given.
//  3. https://huggingface.co.
//
// Surrounding whitespace and trailing slashes are removed from the configured
// value. Callers append "/owner/repo/..." to the result, and compare it against
// "https://huggingface.co" to detect a mirror, so "https://mirror/" must not
// yield "https://mirror//owner/...".
func loadConfig() string {
	mirror := strings.TrimSpace(os.Getenv("HF_MIRROR"))
	if mirror == "" {
		mirror = strings.TrimSpace(os.Getenv("HF"))
	}
	if mirror != "" {
		if !strings.HasPrefix(mirror, "http://") && !strings.HasPrefix(mirror, "https://") {
			mirror = "https://" + mirror
		}
		return strings.TrimRight(mirror, "/")
	}

	endpoint := strings.TrimRight(strings.TrimSpace(os.Getenv("HF_ENDPOINT")), "/")
	if endpoint == "" {
		return "https://huggingface.co"
	}
	return endpoint
}
// ReadWithCallback is ReadWithAuthorizationAndCallback run under
// context.Background() without an Authorization header.
func (uri URI) ReadWithCallback(basePath string, f func(url string, i []byte) error) error {
	const noAuthorization = ""
	ctx := context.Background()
	return uri.ReadWithAuthorizationAndCallback(ctx, basePath, noAuthorization, f)
}
// ReadWithAuthorizationAndCallback loads the whole content referenced by uri
// into memory and passes it, together with the resolved location, to f.
//
// file:// URIs are read from disk, but only after symlinks in both the target
// and basePath are resolved and the target is confirmed to lie under basePath
// (utils.InTrustedRoot); anything outside is rejected with that error.
//
// Every other URI is resolved with ResolveURL and fetched with a GET through
// downloadClient. authorization, when non-empty, is sent verbatim as the
// Authorization header. A response with status >= 400 is returned as an error
// rather than being handed to f as if it were content.
//
// The result of f is returned as-is.
func (uri URI) ReadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error {
	resolved := uri.ResolveURL()

	if strings.HasPrefix(string(uri), LocalPrefix) {
		// Resolve symlinks so the containment check below is made against the
		// real target; paths without symlinks come back unchanged.
		resolvedFile, err := filepath.EvalSymlinks(resolved)
		if err != nil {
			return err
		}
		resolvedBasePath, err := filepath.EvalSymlinks(basePath)
		if err != nil {
			return err
		}
		// Refuse to read files outside basePath.
		err = utils.InTrustedRoot(resolvedFile, resolvedBasePath)
		if err != nil {
			xlog.Debug("downloader.GetURI blocked an attempt to ready a file url outside of basePath", "resolvedFile", resolvedFile, "basePath", basePath)
			return err
		}
		body, err := os.ReadFile(resolvedFile)
		if err != nil {
			return err
		}
		return f(resolved, body)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, resolved, nil)
	if err != nil {
		return err
	}
	if authorization != "" {
		req.Header.Set("Authorization", authorization)
	}

	response, err := downloadClient.Do(req)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	// Without this check a 404/500 error page would reach f as if it were the
	// requested document.
	if response.StatusCode >= 400 {
		return fmt.Errorf("failed to fetch %q: unexpected status code %d", resolved, response.StatusCode)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return err
	}
	return f(resolved, body)
}
// FilenameFromUrl derives a local file name for the URI.
//
// It prefers the last path segment reported by filenameFromUrl. When that is
// empty, it falls back to the MD5 of the whole URI, with ".yaml" appended if
// the URI ends in ".yaml" or ".yml". The returned error is always nil.
func (u URI) FilenameFromUrl() (string, error) {
	raw := string(u)
	if name := filenameFromUrl(raw); name != "" {
		return name, nil
	}

	fallback := utils.MD5(raw)
	switch {
	case strings.HasSuffix(raw, ".yaml"), strings.HasSuffix(raw, ".yml"):
		return fallback + ".yaml", nil
	default:
		return fallback, nil
	}
}
// filenameFromUrl returns the percent-decoded last path segment of urlstr, for
// use as a local file name. It returns "" when urlstr cannot be parsed or its
// path has no usable segment (empty path or just "/").
//
// A trailing "@ref" suffix, as in "huggingface://owner/repo/file.gguf@main" or
// "oci://registry/repo@sha256:...", is ignored. Only an '@' with no '/' after it
// counts as such a suffix. An '@' inside the authority ("https://user@host/f")
// or within the path ("owner/repo@branch/file") therefore no longer truncates
// the URL, which used to produce the wrong name.
func filenameFromUrl(urlstr string) string {
	if at := strings.LastIndex(urlstr, "@"); at != -1 && !strings.Contains(urlstr[at+1:], "/") {
		urlstr = urlstr[:at]
	}

	u, err := url.Parse(urlstr)
	if err != nil {
		return ""
	}
	// NOTE(review): QueryUnescape also turns '+' into ' ', which is not how URL
	// paths are decoded; kept so names stay stable with earlier versions.
	x, err := url.QueryUnescape(u.EscapedPath())
	if err != nil {
		return ""
	}

	// filepath.Base reports "." for an empty path and the separator for a root
	// path; neither is a usable file name.
	name := filepath.Base(x)
	if name == "." || name == "/" || name == string(filepath.Separator) {
		return ""
	}
	return name
}
// LooksLikeURL reports whether the URI starts with one of the remote schemes
// the downloader understands: http(s)://, the HuggingFace shorthands, the
// GitHub shorthands, ollama:// or oci://. Local (file://) and tarball
// (ocifile://) URIs are not considered URLs.
func (u URI) LooksLikeURL() bool {
	s := string(u)
	for _, prefix := range [...]string{
		HTTPPrefix,
		HTTPSPrefix,
		HuggingFacePrefix,
		HuggingFacePrefix1,
		HuggingFacePrefix2,
		GithubURI,
		OllamaPrefix,
		OCIPrefix,
		GithubURI2,
	} {
		if strings.HasPrefix(s, prefix) {
			return true
		}
	}
	return false
}
// LooksLikeHTTPURL reports whether the URI literally begins with http:// or
// https://. Shorthands that only resolve to HTTP(S) do not count.
func (u URI) LooksLikeHTTPURL() bool {
	s := string(u)
	switch {
	case strings.HasPrefix(s, HTTPPrefix), strings.HasPrefix(s, HTTPSPrefix):
		return true
	default:
		return false
	}
}
// LooksLikeDir reports whether the URI, taken as a filesystem path, names an
// existing directory. Any stat error yields false.
func (u URI) LooksLikeDir() bool {
	info, err := os.Stat(string(u))
	if err != nil {
		return false
	}
	return info.IsDir()
}
// LooksLikeOCI reports whether the URI should be handled as a container
// image source: oci://, ollama:// or ocifile:// URIs, or references that start
// directly with the quay.io, ghcr.io or docker.io registry hosts.
func (s URI) LooksLikeOCI() bool {
	raw := string(s)
	for _, prefix := range [...]string{
		"quay.io",
		OCIPrefix,
		OllamaPrefix,
		OCIFilePrefix,
		"ghcr.io",
		"docker.io",
	} {
		if strings.HasPrefix(raw, prefix) {
			return true
		}
	}
	return false
}
// LooksLikeOCIFile reports whether the URI points at a local OCI image
// tarball, i.e. uses the ocifile:// scheme.
func (s URI) LooksLikeOCIFile() bool {
	_, isTarball := strings.CutPrefix(string(s), OCIFilePrefix)
	return isTarball
}
// ResolveURL expands the downloader's URI shorthands into a concrete location:
//
//   - file://PATH returns PATH.
//   - github://ORG/REPO/PATH[@BRANCH] and github:ORG/REPO/PATH[@BRANCH] map to
//     https://raw.githubusercontent.com/ORG/REPO/BRANCH/PATH, with BRANCH
//     defaulting to "main".
//   - huggingface://, hf:// and hf.co/ followed by OWNER/REPO/PATH map to
//     HF_ENDPOINT/OWNER/REPO/resolve/BRANCH/PATH. BRANCH defaults to "main" and
//     may be written on the repository (OWNER/REPO@BRANCH/PATH) or as a
//     trailing suffix (OWNER/REPO/PATH@BRANCH).
//   - https://huggingface.co/... is rewritten onto HF_ENDPOINT when that
//     differs from the default, so hard-coded gallery links use the mirror.
//
// Shorthands with too few path segments to resolve, and every other input,
// are returned unchanged. Malformed shorthands no longer panic.
func (s URI) ResolveURL() string {
	raw := string(s)
	switch {
	case strings.HasPrefix(raw, LocalPrefix):
		return strings.TrimPrefix(raw, LocalPrefix)
	// github:// must be tested before github:, which is a prefix of it.
	case strings.HasPrefix(raw, GithubURI2):
		return resolveGithubShorthand(raw, strings.TrimPrefix(raw, GithubURI2))
	case strings.HasPrefix(raw, GithubURI):
		return resolveGithubShorthand(raw, strings.TrimPrefix(raw, GithubURI))
	case strings.HasPrefix(raw, HuggingFacePrefix):
		return resolveHuggingFaceShorthand(raw, strings.TrimPrefix(raw, HuggingFacePrefix))
	case strings.HasPrefix(raw, HuggingFacePrefix1):
		return resolveHuggingFaceShorthand(raw, strings.TrimPrefix(raw, HuggingFacePrefix1))
	case strings.HasPrefix(raw, HuggingFacePrefix2):
		return resolveHuggingFaceShorthand(raw, strings.TrimPrefix(raw, HuggingFacePrefix2))
	}

	// With a HuggingFace mirror configured, rewrite direct huggingface.co links
	// so gallery entries with hardcoded URLs also benefit from the mirror.
	if HF_ENDPOINT != "https://huggingface.co" && strings.HasPrefix(raw, "https://huggingface.co/") {
		return HF_ENDPOINT + strings.TrimPrefix(raw, "https://huggingface.co")
	}
	return raw
}

// resolveGithubShorthand turns "ORG/REPO/PATH[@BRANCH]" into a
// raw.githubusercontent.com URL, or returns original when ORG/REPO is missing.
func resolveGithubShorthand(original, repository string) string {
	repoPath, branch, hasBranch := strings.Cut(repository, "@")
	if !hasBranch {
		branch = "main"
	}
	parts := strings.Split(repoPath, "/")
	if len(parts) < 2 {
		return original
	}
	return fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s",
		parts[0], parts[1], branch, strings.Join(parts[2:], "/"))
}

// resolveHuggingFaceShorthand turns "OWNER/REPO[@BRANCH]/PATH" or
// "OWNER/REPO/PATH[@BRANCH]" into a HF_ENDPOINT resolve URL, or returns
// original when fewer than three path segments are present.
// e.g. TheBloke/Mixtral-8x7B-v0.1-GGUF/mixtral-8x7b-v0.1.Q2_K.gguf@main ->
// <HF_ENDPOINT>/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q2_K.gguf
func resolveHuggingFaceShorthand(original, repository string) string {
	pieces := strings.Split(repository, "/")
	if len(pieces) < 3 {
		return original
	}
	owner, repo := pieces[0], pieces[1]
	filePath := strings.Join(pieces[2:], "/")

	branch := "main"
	if name, ref, ok := strings.Cut(repo, "@"); ok {
		repo = name
		if ref != "" {
			branch = ref
		}
	} else if at := strings.LastIndex(filePath, "@"); at != -1 {
		if ref := filePath[at+1:]; ref != "" {
			branch = ref
		}
		filePath = filePath[:at]
	}
	return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filePath)
}
// removePartialFile deletes tmpFilePath. A file that is already gone counts as
// success; any other failure is logged as a warning and returned.
func removePartialFile(tmpFilePath string) error {
	xlog.Debug("Removing temporary file", "file", tmpFilePath)

	removeErr := os.Remove(tmpFilePath)
	if removeErr == nil || errors.Is(removeErr, os.ErrNotExist) {
		return nil
	}

	wrapped := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, removeErr)
	xlog.Warn("failed to remove temporary download file", "error", wrapped)
	return wrapped
}
// calculateHashForPartialFile returns a SHA-256 hasher already fed with the
// bytes of file from its current offset to EOF. A resumed download can then
// keep writing new bytes into the same hasher. On success the file offset is
// left at EOF.
func calculateHashForPartialFile(file *os.File) (hash.Hash, error) {
	h := sha256.New()
	if _, err := io.Copy(h, file); err != nil {
		return nil, err
	}
	return h, nil
}
// downloadClient is the shared HTTP client for downloads, size probes and
// authorised reads in this package.
//
// It follows redirects, because model hosts and CDNs rely on them, but strips
// credential headers on any cross-host hop. It sets no overall body deadline,
// so long downloads are not cut short.
var downloadClient = httpclient.New(
	httpclient.WithFollowRedirects(),
)

// checkSeverSupportsRangeHeader sends a HEAD request for the resolved URI and
// reports whether the server advertises "Accept-Ranges: bytes", i.e. whether
// a partial download can be resumed. The request is not bound to a context.
func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
	target := uri.ResolveURL()
	resp, err := downloadClient.Head(target)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()

	acceptRanges := resp.Header.Get("Accept-Ranges")
	return acceptRanges == "bytes", nil
}
// ContentLength returns the size in bytes of the resource at the URI.
//
// file:// URIs are measured with os.Stat on the resolved path. Resolved
// HTTP(S) locations are probed with HEAD; a status >= 400 is an error, and a
// known Content-Length is returned directly.
//
// Otherwise, if the server advertises "Accept-Ranges: bytes", a GET for the
// byte range 0-0 is sent and the total is read from its Content-Range header
// ("bytes 0-0/<total>"). If the server ignores the Range header and answers
// 200, that response's own Content-Length is used when present; its body is
// closed, not read. Other schemes are rejected.
func (u URI) ContentLength(ctx context.Context) (int64, error) {
	urlStr := u.ResolveURL()
	if strings.HasPrefix(string(u), LocalPrefix) {
		info, err := os.Stat(urlStr)
		if err != nil {
			return 0, err
		}
		return info.Size(), nil
	}
	if !strings.HasPrefix(urlStr, HTTPPrefix) && !strings.HasPrefix(urlStr, HTTPSPrefix) {
		return 0, fmt.Errorf("unsupported URI scheme for ContentLength: %s", string(u))
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodHead, urlStr, nil)
	if err != nil {
		return 0, err
	}
	resp, err := downloadClient.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 400 {
		return 0, fmt.Errorf("HEAD %s: status %d", urlStr, resp.StatusCode)
	}
	if resp.ContentLength >= 0 {
		return resp.ContentLength, nil
	}
	if resp.Header.Get("Accept-Ranges") != "bytes" {
		return 0, fmt.Errorf("HEAD %s: no Content-Length and server does not support Range", urlStr)
	}

	req2, err := http.NewRequestWithContext(ctx, http.MethodGet, urlStr, nil)
	if err != nil {
		return 0, err
	}
	req2.Header.Set("Range", "bytes=0-0")
	resp2, err := downloadClient.Do(req2)
	if err != nil {
		return 0, err
	}
	defer resp2.Body.Close()

	switch resp2.StatusCode {
	case http.StatusPartialContent:
		// Expected answer; the total follows in Content-Range.
	case http.StatusOK:
		// The server ignored Range and is sending the whole resource, so its
		// Content-Length (when known) is the size. No Content-Range will follow.
		if resp2.ContentLength >= 0 {
			return resp2.ContentLength, nil
		}
	default:
		return 0, fmt.Errorf("Range request %s: status %d", urlStr, resp2.StatusCode)
	}

	// Content-Range: bytes 0-0/12345
	cr := resp2.Header.Get("Content-Range")
	if cr == "" {
		return 0, fmt.Errorf("Range request %s: no Content-Range header", urlStr)
	}
	parts := strings.Split(cr, "/")
	if len(parts) != 2 {
		return 0, fmt.Errorf("invalid Content-Range: %s", cr)
	}
	size, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
	if err != nil || size < 0 {
		return 0, fmt.Errorf("invalid Content-Range total length: %s", parts[1])
	}
	return size, nil
}
// DownloadFile is DownloadFileWithContext bound to context.Background(), so
// the caller cannot cancel the download.
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64), opts ...DownloadOption) error {
	ctx := context.Background()
	return uri.DownloadFileWithContext(ctx, filePath, sha, fileN, total, downloadStatus, opts...)
}
// DownloadFileWithContext materialises the resource referenced by uri at
// filePath.
//
// OCI-like URIs (LooksLikeOCI) take a separate path:
//   - ollama:// models are fetched with oci.OllamaFetchModel;
//   - ocifile:// tarballs and registry images are unpacked with
//     oci.ExtractOCIImage;
//   - with WithImageVerifier, a registry image is verified against its resolved
//     digest before extraction.
//
// Every other URI is resolved with ResolveURL and handled as follows:
//   - An existing regular file at filePath is kept when sha is empty or
//     matches. A mismatching file is deleted and re-downloaded.
//   - Data is streamed into filePath+".partial". An existing partial is resumed
//     with a Range request when the server advertises byte ranges, and
//     discarded otherwise, or when the server does not honour the Range.
//   - When sha is non-empty, the streamed SHA-256 must match before the partial
//     is renamed to filePath. A mismatch removes the partial and returns an
//     error.
//   - Archives (utils.IsArchive) are extracted next to filePath afterwards.
//
// ctx is checked before the download starts, during the copy and before the
// final rename. On cancellation the partial file is removed.
// fileN, total and downloadStatus feed progress reporting.
func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64), opts ...DownloadOption) error {
	dopts := applyDownloadOptions(opts)
	url := uri.ResolveURL()
	if uri.LooksLikeOCI() {
		// Only Ollama wants to download to the file, for the rest, we want to download to the directory
		// so we check if filepath has any extension, otherwise we assume it's a directory.
		// Caveat: `filepath.Ext` treats any dot-suffix as an extension, so paths like
		// `backends/local-store.upgrade-tmp` (the tmp dir created by gallery.UpgradeBackend)
		// look like a "file" to this heuristic and get rewritten to their parent — which
		// then unpacks the image at `backends/` top level and clobbers the real install
		// with a flat-layout file. Guard against that by short-circuiting when the caller
		// has already created the target as a directory: OCI destinations are always dirs
		// in that case, regardless of what their suffix looks like.
		if !strings.HasPrefix(url, OllamaPrefix) {
			if fi, statErr := os.Stat(filePath); statErr == nil && fi.IsDir() {
				// Existing directory — use as-is.
			} else if filepath.Ext(filePath) != "" {
				filePath = filepath.Dir(filePath)
			}
		}

		progressStatus := func(desc ocispec.Descriptor) io.Writer {
			return &progressWriter{
				fileName:       filePath,
				total:          desc.Size,
				hash:           sha256.New(),
				fileNo:         fileN,
				totalFiles:     total,
				downloadStatus: downloadStatus,
			}
		}

		if url, ok := strings.CutPrefix(url, OllamaPrefix); ok {
			return oci.OllamaFetchModel(ctx, url, filePath, progressStatus)
		}

		if url, ok := strings.CutPrefix(url, OCIFilePrefix); ok {
			// Open the tarball
			img, err := tarball.ImageFromPath(url, nil)
			if err != nil {
				return fmt.Errorf("failed to open tarball: %s", err.Error())
			}
			return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
		}

		url = strings.TrimPrefix(url, OCIPrefix)
		img, err := oci.GetImage(url, "", nil, nil)
		if err != nil {
			return fmt.Errorf("failed to get image %q: %v", url, err)
		}

		// Verify before extract so tampered bytes never reach disk. We
		// re-pin the ref to the manifest digest we just fetched: the
		// verifier would otherwise resolve the tag again, opening a tiny
		// TOCTOU window in which a registry could swap the underlying
		// manifest between the two HEADs.
		if dopts.verifier != nil {
			digest, derr := img.Digest()
			if derr != nil {
				return fmt.Errorf("resolving digest for verification of %q: %v", url, derr)
			}
			pinned := pinnedImageRef(url, digest.String())
			if verr := dopts.verifier.VerifyImage(ctx, pinned); verr != nil {
				return fmt.Errorf("image verification failed for %q: %w", url, verr)
			}
			xlog.Info("Image signature verified", "ref", pinned)
		}

		return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
	}

	// Check for cancellation before starting
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}

	// Check if the file already exists
	fi, err := os.Stat(filePath)
	if err == nil {
		// Directories don't count as cached downloads (e.g. empty dirs left
		// by failed OCI extractions). Only skip for regular files.
		if fi.IsDir() {
			xlog.Debug("[downloader] Path is a directory, not treating as cached download", "filePath", filePath)
		} else {
			xlog.Debug("[downloader] File already exists", "filePath", filePath)
			if sha != "" {
				calculatedSHA, err := CalculateSHA(filePath)
				if err != nil {
					return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err)
				}
				if calculatedSHA == sha {
					xlog.Debug("File already exists and matches the SHA. Skipping download", "file", filePath)
					return nil
				}
				// SHA doesn't match, delete the file and download again
				err = os.Remove(filePath)
				if err != nil {
					return fmt.Errorf("failed to remove existing file %q: %v", filePath, err)
				}
				xlog.Debug("Removed file (SHA doesn't match)", "file", filePath)
			} else {
				xlog.Debug("File already exists. Skipping download", "file", filePath)
				return nil
			}
		}
	} else if !os.IsNotExist(err) || !URI(url).LooksLikeHTTPURL() {
		// Error occurred while checking file existence
		return fmt.Errorf("could not fetch %q: local file does not exist (%v) and %q is not a recognized downloadable URL (supported schemes: %s)", filePath, err, url, strings.Join([]string{HTTPPrefix, HTTPSPrefix, LocalPrefix, HuggingFacePrefix, HuggingFacePrefix1, OllamaPrefix, OCIPrefix, OCIFilePrefix, GithubURI2}, ", "))
	}

	xlog.Info("Downloading", "url", url)

	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
	if err != nil {
		return fmt.Errorf("failed to create request for %q: %v", filePath, err)
	}

	// save partial download to dedicated file
	tmpFilePath := filePath + ".partial"
	tmpFileInfo, err := os.Stat(tmpFilePath)
	switch {
	case err == nil && URI(url).LooksLikeHTTPURL():
		// Test the RESOLVED url: shorthands such as huggingface:// resolve to
		// HTTPS and must be resumable too. Testing the raw uri used to make
		// them fail here with a "<nil>" error.
		support, err := uri.checkSeverSupportsRangeHeader()
		if err != nil {
			return fmt.Errorf("failed to check if uri server supports range header: %v", err)
		}
		if support {
			startPos := tmpFileInfo.Size()
			req.Header.Set("Range", fmt.Sprintf("bytes=%d-", startPos))
		} else if err := removePartialFile(tmpFilePath); err != nil {
			return err
		}
	case err == nil:
		// A stale partial for a source that cannot be resumed: start over so
		// new data is not appended after it.
		if err := removePartialFile(tmpFilePath); err != nil {
			return err
		}
	case !errors.Is(err, os.ErrNotExist):
		return fmt.Errorf("failed to check file %q existence: %v", filePath, err)
	}

	var source io.ReadCloser
	var contentLength int64
	if _, statErr := os.Stat(url); strings.HasPrefix(string(uri), LocalPrefix) || statErr == nil {
		file, err := os.Open(url)
		if err != nil {
			return fmt.Errorf("failed to open file %q: %v", url, err)
		}
		l, err := file.Stat()
		if err != nil {
			file.Close()
			return fmt.Errorf("failed to get file size %q: %v", url, err)
		}
		source = file
		contentLength = l.Size()
	} else {
		resp, err := downloadClient.Do(req)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Clean up partial file on cancellation
				removePartialFile(tmpFilePath)
				return err
			}
			return fmt.Errorf("failed to download file %q: %v", filePath, err)
		}

		resuming := req.Header.Get("Range") != ""
		if resp.StatusCode >= 400 {
			// The body must be closed here as well, or the connection leaks.
			resp.Body.Close()
			if resuming && resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
				// The partial cannot be resumed (typically it already covers
				// the whole resource). Best-effort removal so the next attempt
				// starts fresh instead of failing the same way forever.
				_ = removePartialFile(tmpFilePath)
			}
			return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode)
		}
		if resuming && resp.StatusCode != http.StatusPartialContent {
			// The server ignored the Range header and is sending the whole
			// resource; appending it to the partial would corrupt the file.
			if err := removePartialFile(tmpFilePath); err != nil {
				resp.Body.Close()
				return err
			}
		}
		source = resp.Body
		contentLength = resp.ContentLength
	}
	defer source.Close()

	// Create parent directory
	err = os.MkdirAll(filepath.Dir(filePath), 0750)
	if err != nil {
		return fmt.Errorf("failed to create parent directory for file %q: %v", filePath, err)
	}

	// Append to the partial file (created if missing) so a resumed download
	// continues where it stopped.
	outFile, err := os.OpenFile(tmpFilePath, os.O_APPEND|os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		return fmt.Errorf("failed to create / open file %q: %v", tmpFilePath, err)
	}
	// Safety net for early returns; the success path closes explicitly below.
	defer outFile.Close()

	// Seed the hash with any bytes already present so the final digest covers
	// the whole file.
	partialHash, err := calculateHashForPartialFile(outFile)
	if err != nil {
		return fmt.Errorf("failed to calculate hash for partial file %q: %w", tmpFilePath, err)
	}

	progress := &progressWriter{
		fileName:       tmpFilePath,
		total:          contentLength,
		hash:           partialHash,
		fileNo:         fileN,
		totalFiles:     total,
		downloadStatus: downloadStatus,
		ctx:            ctx,
	}
	_, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source)
	if err != nil {
		if errors.Is(err, context.Canceled) {
			// Clean up partial file on cancellation
			removePartialFile(tmpFilePath)
			return err
		}
		return fmt.Errorf("failed to write file %q: %v", filePath, err)
	}

	// Check for cancellation before finalizing
	select {
	case <-ctx.Done():
		removePartialFile(tmpFilePath)
		return ctx.Err()
	default:
	}

	// Close before verification and rename: this surfaces deferred write errors
	// and avoids renaming a file that is still open.
	if err := outFile.Close(); err != nil {
		return fmt.Errorf("failed to close file %q: %v", tmpFilePath, err)
	}

	// Invariant: verify the streamed hash before promoting the temp file to
	// the final path. Renaming first would leave tampered content reachable
	// to subsequent readers even though we return an error.
	if sha != "" {
		calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
		if calculatedSHA != sha {
			xlog.Debug("SHA mismatch for file", "file", filePath, "calculated", calculatedSHA, "metadata", sha)
			_ = removePartialFile(tmpFilePath)
			return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
		}
	} else {
		// Visible at the default log level so missing-digest configs are
		// noticed; silent acceptance was the historical bug.
		xlog.Warn("downloading without integrity check — supplied SHA is empty",
			"file", filePath,
			"url", url,
		)
	}

	err = os.Rename(tmpFilePath, filePath)
	if err != nil {
		return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
	}

	xlog.Info("File downloaded and verified", "file", filePath)
	if utils.IsArchive(filePath) {
		basePath := filepath.Dir(filePath)
		xlog.Info("File is an archive, uncompressing", "file", filePath, "basePath", basePath)
		if err := utils.ExtractArchive(filePath, basePath); err != nil {
			xlog.Debug("Failed decompressing", "file", filePath, "error", err)
			return err
		}
	}

	return nil
}
// formatBytes renders a byte count for humans. Values below 1024 are printed
// exactly ("512 B"). Larger values use binary units with one decimal
// ("1.5 KiB", "3.0 GiB"), up to EiB.
func formatBytes(bytes int64) string {
	const unit = 1024
	if bytes < unit {
		return strconv.FormatInt(bytes, 10) + " B"
	}

	const prefixes = "KMGTPE"
	idx, scale := 0, int64(unit)
	// Grow the scale while at least one whole unit of the next size fits.
	// scale*unit never exceeds bytes here, so it cannot overflow.
	for bytes/scale >= unit {
		scale *= unit
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(scale), prefixes[idx])
}
// CalculateSHA streams the file at filePath through SHA-256 and returns the
// digest as lowercase hex. Open and read errors are returned unchanged.
func CalculateSHA(filePath string) (string, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer f.Close()

	digest := sha256.New()
	if _, err = io.Copy(digest, f); err != nil {
		return "", err
	}
	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum), nil
}