From 952635fba636e5ef16b35d4ea0278bfd1207bad2 Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Tue, 31 Mar 2026 15:28:13 +0100 Subject: [PATCH] feat(distributed): Avoid resending models to backend nodes (#9193) Signed-off-by: Richard Palethorpe Co-authored-by: Ettore Di Giacinto --- core/services/nodes/file_stager_http.go | 50 +++ core/services/nodes/file_transfer_server.go | 100 ++++- .../nodes/file_transfer_server_test.go | 368 ++++++++++++++++++ pkg/downloader/uri.go | 4 +- 4 files changed, 516 insertions(+), 6 deletions(-) diff --git a/core/services/nodes/file_stager_http.go b/core/services/nodes/file_stager_http.go index f22ea0f1f..f2a74c433 100644 --- a/core/services/nodes/file_stager_http.go +++ b/core/services/nodes/file_stager_http.go @@ -17,6 +17,7 @@ import ( "time" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/xlog" ) @@ -87,6 +88,12 @@ func (h *HTTPFileStager) EnsureRemote(ctx context.Context, nodeID, localPath, ke return "", fmt.Errorf("resolving HTTP address for node %s: %w", nodeID, err) } + // Probe: check if the remote already has the file with matching content hash. + if remotePath, ok := h.probeExisting(ctx, addr, localPath, key); ok { + xlog.Info("Upload skipped (file already exists with matching hash)", "node", nodeID, "key", key, "remotePath", remotePath) + return remotePath, nil + } + fi, err := os.Stat(localPath) if err != nil { return "", fmt.Errorf("stat local file %s: %w", localPath, err) @@ -217,6 +224,49 @@ func isTransientError(err error) bool { return false } +// probeExisting sends a HEAD request to check if the remote already has the +// file with a matching SHA-256 hash. Returns the remote path and true if the +// upload can be skipped. Any errors (including 405 from older servers) silently +// fall through so the caller proceeds with a normal PUT. +func (h *HTTPFileStager) probeExisting(ctx context.Context, addr, localPath, key string) (string, bool) { + url := fmt.Sprintf("http://%s/v1/files/%s", addr, key) + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return "", false + } + if h.token != "" { + req.Header.Set("Authorization", "Bearer "+h.token) + } + + resp, err := h.client.Do(req) + if err != nil { + return "", false + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", false + } + + remotePath := resp.Header.Get(HeaderLocalPath) + remoteHash := resp.Header.Get(HeaderContentSHA256) + if remotePath == "" || remoteHash == "" { + return "", false + } + + localHash, err := downloader.CalculateSHA(localPath) + if err != nil { + return "", false + } + + if localHash != remoteHash { + return "", false + } + + return remotePath, true +} + // progressReader wraps an io.Reader and logs upload progress periodically. type progressReader struct { reader io.Reader diff --git a/core/services/nodes/file_transfer_server.go b/core/services/nodes/file_transfer_server.go index b60713c16..611fa2974 100644 --- a/core/services/nodes/file_transfer_server.go +++ b/core/services/nodes/file_transfer_server.go @@ -2,6 +2,9 @@ package nodes import ( "context" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" "encoding/json" "fmt" "io" @@ -10,18 +13,26 @@ import ( "net/url" "os" "path/filepath" + "strconv" "strings" "sync" "time" - "crypto/subtle" - "github.com/gorilla/websocket" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) +// Headers used by the HEAD probe / skip-if-exists protocol. +const ( + HeaderContentSHA256 = "X-Content-SHA256" + HeaderLocalPath = "X-Local-Path" + HeaderFileSize = "X-File-Size" + hashSidecarSuffix = ".sha256" +) + // StartFileTransferServer starts a small HTTP server for file transfer in distributed mode. // It provides PUT/GET/POST endpoints for uploading, downloading, and allocating temp files, // as well as backend log REST and WebSocket endpoints when logStore is non-nil. @@ -72,6 +83,8 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir xlog.Debug("HTTP file transfer request", "method", r.Method, "key", key, "remote", r.RemoteAddr) switch r.Method { + case http.MethodHead: + handleHead(w, r, stagingDir, modelsDir, dataDir, key) case http.MethodPut: handleUpload(w, r, stagingDir, modelsDir, dataDir, key, maxUploadSize) case http.MethodGet: @@ -112,6 +125,47 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir return server, nil } +func handleHead(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string) { + if key == "" { + http.Error(w, "key is required", http.StatusBadRequest) + return + } + + targetDir, relName := resolveKeyToDir(key, stagingDir, modelsDir, dataDir) + filePath := filepath.Join(targetDir, relName) + + if err := validatePathInDir(filePath, targetDir); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + info, err := os.Stat(filePath) + if err != nil { + if os.IsNotExist(err) { + http.Error(w, "not found", http.StatusNotFound) + } else { + http.Error(w, fmt.Sprintf("stat error: %v", err), http.StatusInternalServerError) + } + return + } + if info.IsDir() { + http.Error(w, "is a directory", http.StatusBadRequest) + return + } + + w.Header().Set(HeaderFileSize, strconv.FormatInt(info.Size(), 10)) + w.Header().Set(HeaderLocalPath, filePath) + + hashHex, err := computeAndCacheHash(filePath) + if err != nil { + xlog.Warn("Failed to compute hash for HEAD", "path", filePath, "error", err) + } else { + w.Header().Set(HeaderContentSHA256, hashHex) + } + + w.WriteHeader(http.StatusOK) +} + func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string, maxUploadSize int64) { if key == "" { http.Error(w, "key is required", http.StatusBadRequest) @@ -146,15 +200,22 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, } defer f.Close() - n, err := io.Copy(f, r.Body) + hasher := sha256.New() + n, err := io.Copy(f, io.TeeReader(r.Body, hasher)) if err != nil { os.Remove(dstPath) + os.Remove(dstPath + hashSidecarSuffix) xlog.Error("File upload failed", "key", key, "bytesReceived", n, "contentLength", r.ContentLength, "remote", r.RemoteAddr, "error", err) http.Error(w, fmt.Sprintf("writing file: %v", err), http.StatusInternalServerError) return } - xlog.Info("File upload complete", "key", key, "path", dstPath, "size", n) + hashHex := hex.EncodeToString(hasher.Sum(nil)) + if err := os.WriteFile(dstPath+hashSidecarSuffix, []byte(hashHex), 0640); err != nil { + xlog.Warn("Failed to write hash sidecar", "path", dstPath+hashSidecarSuffix, "error", err) + } + + xlog.Info("File upload complete", "key", key, "path", dstPath, "size", n, "sha256", hashHex) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]string{"local_path": dstPath}); err != nil { @@ -162,6 +223,37 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, } } +// computeAndCacheHash returns the SHA-256 hex digest for filePath. +// It reads a cached sidecar when available and still fresh (sidecar mtime >= +// file mtime), otherwise computes the hash and writes/updates the sidecar. +func computeAndCacheHash(filePath string) (string, error) { + sidecar := filePath + hashSidecarSuffix + + fileStat, err := os.Stat(filePath) + if err != nil { + return "", err + } + + if sidecarStat, err := os.Stat(sidecar); err == nil && !sidecarStat.ModTime().Before(fileStat.ModTime()) { + if data, err := os.ReadFile(sidecar); err == nil { + h := strings.TrimSpace(string(data)) + if len(h) == 64 { // valid hex-encoded SHA-256 + return h, nil + } + } + } + + hashHex, err := downloader.CalculateSHA(filePath) + if err != nil { + return "", err + } + + if err := os.WriteFile(sidecar, []byte(hashHex), 0640); err != nil { + xlog.Warn("Failed to write hash sidecar", "path", sidecar, "error", err) + } + return hashHex, nil +} + func handleDownload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string) { if key == "" { http.Error(w, "key is required", http.StatusBadRequest) diff --git a/core/services/nodes/file_transfer_server_test.go b/core/services/nodes/file_transfer_server_test.go index 5c352bfa8..e5ed9ba2e 100644 --- a/core/services/nodes/file_transfer_server_test.go +++ b/core/services/nodes/file_transfer_server_test.go @@ -2,6 +2,9 @@ package nodes import ( "bytes" + "context" + "crypto/sha256" + "encoding/hex" "io" "net/http" "net/http/httptest" @@ -28,6 +31,8 @@ var _ = Describe("FileTransferServer", func() { } key := strings.TrimPrefix(r.URL.Path, "/v1/files/") switch r.Method { + case http.MethodHead: + handleHead(w, r, stagingDir, modelsDir, dataDir, key) case http.MethodPut: handleUpload(w, r, stagingDir, modelsDir, dataDir, key, maxUploadSize) case http.MethodGet: @@ -196,4 +201,367 @@ var _ = Describe("FileTransferServer", func() { Expect(string(body)).To(Equal(content)) }) }) + + // --- HEAD handler tests --- + + Describe("HEAD probe", func() { + It("returns 404 for non-existent file", func() { + ts, _, _, _ := setupTestServer("tok", 0) + + req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/missing.bin", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Authorization", "Bearer tok") + resp, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + resp.Body.Close() + Expect(resp.StatusCode).To(Equal(http.StatusNotFound)) + }) + + It("returns 200 with hash, size, and path for an uploaded file", func() { + ts, _, _, _ := setupTestServer("tok", 0) + + content := "hello distributed world" + expectedHash := sha256Hex([]byte(content)) + + // Upload first + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/probe.txt", strings.NewReader(content)) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer tok") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + Expect(putResp.StatusCode).To(Equal(http.StatusOK)) + + // HEAD should return hash + size + path + headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/probe.txt", nil) + Expect(err).ToNot(HaveOccurred()) + headReq.Header.Set("Authorization", "Bearer tok") + headResp, err := http.DefaultClient.Do(headReq) + Expect(err).ToNot(HaveOccurred()) + headResp.Body.Close() + Expect(headResp.StatusCode).To(Equal(http.StatusOK)) + Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(expectedHash)) + Expect(headResp.Header.Get(HeaderFileSize)).To(Equal("23")) + Expect(headResp.Header.Get(HeaderLocalPath)).ToNot(BeEmpty()) + }) + + It("routes models/ prefixed keys to modelsDir", func() { + ts, _, modelsDir, _ := setupTestServer("tok", 0) + + content := "model data" + + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/models/test/w.bin", strings.NewReader(content)) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer tok") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + + headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/models/test/w.bin", nil) + Expect(err).ToNot(HaveOccurred()) + headReq.Header.Set("Authorization", "Bearer tok") + headResp, err := http.DefaultClient.Do(headReq) + Expect(err).ToNot(HaveOccurred()) + headResp.Body.Close() + Expect(headResp.StatusCode).To(Equal(http.StatusOK)) + Expect(headResp.Header.Get(HeaderLocalPath)).To(HavePrefix(modelsDir)) + }) + + It("computes and caches hash for pre-existing file without sidecar", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + // Write file directly (no upload, no sidecar) + content := []byte("pre-existing content") + err := os.WriteFile(filepath.Join(stagingDir, "legacy.bin"), content, 0644) + Expect(err).ToNot(HaveOccurred()) + + expectedHash := sha256Hex(content) + + headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/legacy.bin", nil) + Expect(err).ToNot(HaveOccurred()) + headReq.Header.Set("Authorization", "Bearer tok") + headResp, err := http.DefaultClient.Do(headReq) + Expect(err).ToNot(HaveOccurred()) + headResp.Body.Close() + Expect(headResp.StatusCode).To(Equal(http.StatusOK)) + Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(expectedHash)) + + // Sidecar should now be cached + sidecar, err := os.ReadFile(filepath.Join(stagingDir, "legacy.bin.sha256")) + Expect(err).ToNot(HaveOccurred()) + Expect(string(sidecar)).To(Equal(expectedHash)) + }) + + It("returns updated hash after re-upload with different content", func() { + ts, _, _, _ := setupTestServer("tok", 0) + + // Upload v1 + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/changing.bin", strings.NewReader("version1")) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer tok") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + + hash1 := sha256Hex([]byte("version1")) + + headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/changing.bin", nil) + Expect(err).ToNot(HaveOccurred()) + headReq.Header.Set("Authorization", "Bearer tok") + headResp, err := http.DefaultClient.Do(headReq) + Expect(err).ToNot(HaveOccurred()) + headResp.Body.Close() + Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(hash1)) + + // Re-upload v2 + putReq2, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/changing.bin", strings.NewReader("version2")) + Expect(err).ToNot(HaveOccurred()) + putReq2.Header.Set("Authorization", "Bearer tok") + putResp2, err := http.DefaultClient.Do(putReq2) + Expect(err).ToNot(HaveOccurred()) + putResp2.Body.Close() + + hash2 := sha256Hex([]byte("version2")) + Expect(hash2).ToNot(Equal(hash1)) + + headReq2, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/changing.bin", nil) + Expect(err).ToNot(HaveOccurred()) + headReq2.Header.Set("Authorization", "Bearer tok") + headResp2, err := http.DefaultClient.Do(headReq2) + Expect(err).ToNot(HaveOccurred()) + headResp2.Body.Close() + Expect(headResp2.Header.Get(HeaderContentSHA256)).To(Equal(hash2)) + }) + + It("enforces bearer token auth on HEAD", func() { + ts, _, _, _ := setupTestServer("secret", 0) + + content := "authed" + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/auth-head.txt", strings.NewReader(content)) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer secret") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + + // No token + req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/auth-head.txt", nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + resp.Body.Close() + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + + // Wrong token + req2, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/auth-head.txt", nil) + Expect(err).ToNot(HaveOccurred()) + req2.Header.Set("Authorization", "Bearer wrong") + resp2, err := http.DefaultClient.Do(req2) + Expect(err).ToNot(HaveOccurred()) + resp2.Body.Close() + Expect(resp2.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + + It("rejects path traversal on HEAD", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + // Place a file that a traversal path might resolve to + target := filepath.Join(filepath.Dir(stagingDir), "escape.txt") + err := os.WriteFile(target, []byte("secret"), 0644) + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func() { os.Remove(target) }) + + // Use a raw request to prevent Go's URL cleaning from collapsing ".." + req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/../escape.txt", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("Authorization", "Bearer tok") + resp, err := http.DefaultClient.Do(req) + Expect(err).ToNot(HaveOccurred()) + resp.Body.Close() + // Go's HTTP mux cleans the URL, so the request becomes /v1/files/escape.txt + // which resolves inside stagingDir. The traversal is handled by validatePathInDir. + // A direct "../" in the key is caught as 400, but Go cleans it first. + // Either 400 (traversal blocked) or 404 (file not in staging) is acceptable. + Expect(resp.StatusCode).To(SatisfyAny( + Equal(http.StatusBadRequest), + Equal(http.StatusNotFound), + )) + }) + }) + + // --- Upload sidecar tests --- + + Describe("Upload hash sidecar", func() { + It("creates a .sha256 sidecar alongside uploaded file", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + content := "sidecar test" + expectedHash := sha256Hex([]byte(content)) + + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/sidecar.txt", strings.NewReader(content)) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer tok") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + Expect(putResp.StatusCode).To(Equal(http.StatusOK)) + + sidecar, err := os.ReadFile(filepath.Join(stagingDir, "sidecar.txt.sha256")) + Expect(err).ToNot(HaveOccurred()) + Expect(string(sidecar)).To(Equal(expectedHash)) + }) + + It("creates sidecar under modelsDir for models/ prefix", func() { + ts, _, modelsDir, _ := setupTestServer("tok", 0) + + content := "model sidecar test" + expectedHash := sha256Hex([]byte(content)) + + putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/models/m1/w.bin", strings.NewReader(content)) + Expect(err).ToNot(HaveOccurred()) + putReq.Header.Set("Authorization", "Bearer tok") + putResp, err := http.DefaultClient.Do(putReq) + Expect(err).ToNot(HaveOccurred()) + putResp.Body.Close() + + sidecar, err := os.ReadFile(filepath.Join(modelsDir, "m1", "w.bin.sha256")) + Expect(err).ToNot(HaveOccurred()) + Expect(string(sidecar)).To(Equal(expectedHash)) + }) + }) + + // --- EnsureRemote skip tests --- + + Describe("EnsureRemote skip-if-exists", func() { + It("skips upload when file exists with matching hash", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + content := []byte("already on worker") + expectedHash := sha256Hex(content) + + // Pre-place file and sidecar on the "worker" + err := os.WriteFile(filepath.Join(stagingDir, "present.bin"), content, 0644) + Expect(err).ToNot(HaveOccurred()) + err = os.WriteFile(filepath.Join(stagingDir, "present.bin.sha256"), []byte(expectedHash), 0644) + Expect(err).ToNot(HaveOccurred()) + + // Create matching local file + localDir := GinkgoT().TempDir() + localPath := filepath.Join(localDir, "present.bin") + err = os.WriteFile(localPath, content, 0644) + Expect(err).ToNot(HaveOccurred()) + + addr := strings.TrimPrefix(ts.URL, "http://") + stager := NewHTTPFileStager(func(nodeID string) (string, error) { + return addr, nil + }, "tok") + + remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "present.bin") + Expect(err).ToNot(HaveOccurred()) + Expect(remotePath).To(Equal(filepath.Join(stagingDir, "present.bin"))) + }) + + It("uploads when file exists but hash differs", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + oldContent := []byte("old version") + oldHash := sha256Hex(oldContent) + + // Pre-place old file and sidecar + err := os.WriteFile(filepath.Join(stagingDir, "changed.bin"), oldContent, 0644) + Expect(err).ToNot(HaveOccurred()) + err = os.WriteFile(filepath.Join(stagingDir, "changed.bin.sha256"), []byte(oldHash), 0644) + Expect(err).ToNot(HaveOccurred()) + + // Create local file with different content + newContent := []byte("new version") + localDir := GinkgoT().TempDir() + localPath := filepath.Join(localDir, "changed.bin") + err = os.WriteFile(localPath, newContent, 0644) + Expect(err).ToNot(HaveOccurred()) + + addr := strings.TrimPrefix(ts.URL, "http://") + stager := NewHTTPFileStager(func(nodeID string) (string, error) { + return addr, nil + }, "tok") + + remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "changed.bin") + Expect(err).ToNot(HaveOccurred()) + Expect(remotePath).ToNot(BeEmpty()) + + // Verify new content was uploaded + uploaded, err := os.ReadFile(filepath.Join(stagingDir, "changed.bin")) + Expect(err).ToNot(HaveOccurred()) + Expect(uploaded).To(Equal(newContent)) + + // Verify sidecar was updated + sidecar, err := os.ReadFile(filepath.Join(stagingDir, "changed.bin.sha256")) + Expect(err).ToNot(HaveOccurred()) + Expect(string(sidecar)).To(Equal(sha256Hex(newContent))) + }) + + It("uploads when HEAD returns 404", func() { + ts, stagingDir, _, _ := setupTestServer("tok", 0) + + content := []byte("fresh upload") + localDir := GinkgoT().TempDir() + localPath := filepath.Join(localDir, "new.bin") + err := os.WriteFile(localPath, content, 0644) + Expect(err).ToNot(HaveOccurred()) + + addr := strings.TrimPrefix(ts.URL, "http://") + stager := NewHTTPFileStager(func(nodeID string) (string, error) { + return addr, nil + }, "tok") + + remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "new.bin") + Expect(err).ToNot(HaveOccurred()) + Expect(remotePath).ToNot(BeEmpty()) + + uploaded, err := os.ReadFile(filepath.Join(stagingDir, "new.bin")) + Expect(err).ToNot(HaveOccurred()) + Expect(uploaded).To(Equal(content)) + }) + + It("falls through to upload when HEAD returns 405 (old server)", func() { + // Set up a server that does NOT support HEAD (simulates old server) + stagingDir := GinkgoT().TempDir() + mux := http.NewServeMux() + mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) { + key := strings.TrimPrefix(r.URL.Path, "/v1/files/") + switch r.Method { + case http.MethodPut: + handleUpload(w, r, stagingDir, "", "", key, 0) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + }) + ts := httptest.NewServer(mux) + DeferCleanup(ts.Close) + + content := []byte("should still upload") + localDir := GinkgoT().TempDir() + localPath := filepath.Join(localDir, "compat.bin") + err := os.WriteFile(localPath, content, 0644) + Expect(err).ToNot(HaveOccurred()) + + addr := strings.TrimPrefix(ts.URL, "http://") + stager := NewHTTPFileStager(func(nodeID string) (string, error) { + return addr, nil + }, "") + + remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "compat.bin") + Expect(err).ToNot(HaveOccurred()) + Expect(remotePath).ToNot(BeEmpty()) + + uploaded, err := os.ReadFile(filepath.Join(stagingDir, "compat.bin")) + Expect(err).ToNot(HaveOccurred()) + Expect(uploaded).To(Equal(content)) + }) + }) }) + +func sha256Hex(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 43edc0992..ed5d6080e 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -433,7 +433,7 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // File exists, check SHA if sha != "" { // Verify SHA - calculatedSHA, err := calculateSHA(filePath) + calculatedSHA, err := CalculateSHA(filePath) if err != nil { return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err) } @@ -609,7 +609,7 @@ func formatBytes(bytes int64) string { return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp]) } -func calculateSHA(filePath string) (string, error) { +func CalculateSHA(filePath string) (string, error) { file, err := os.Open(filePath) if err != nil { return "", err