mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat(distributed): Avoid resending models to backend nodes (#9193)
Signed-off-by: Richard Palethorpe <io@richiejp.com> Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3cc05af2e5
commit
952635fba6
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -87,6 +88,12 @@ func (h *HTTPFileStager) EnsureRemote(ctx context.Context, nodeID, localPath, ke
|
||||
return "", fmt.Errorf("resolving HTTP address for node %s: %w", nodeID, err)
|
||||
}
|
||||
|
||||
// Probe: check if the remote already has the file with matching content hash.
|
||||
if remotePath, ok := h.probeExisting(ctx, addr, localPath, key); ok {
|
||||
xlog.Info("Upload skipped (file already exists with matching hash)", "node", nodeID, "key", key, "remotePath", remotePath)
|
||||
return remotePath, nil
|
||||
}
|
||||
|
||||
fi, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stat local file %s: %w", localPath, err)
|
||||
@@ -217,6 +224,49 @@ func isTransientError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// probeExisting sends a HEAD request to check if the remote already has the
|
||||
// file with a matching SHA-256 hash. Returns the remote path and true if the
|
||||
// upload can be skipped. Any errors (including 405 from older servers) silently
|
||||
// fall through so the caller proceeds with a normal PUT.
|
||||
func (h *HTTPFileStager) probeExisting(ctx context.Context, addr, localPath, key string) (string, bool) {
|
||||
url := fmt.Sprintf("http://%s/v1/files/%s", addr, key)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
if h.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+h.token)
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", false
|
||||
}
|
||||
|
||||
remotePath := resp.Header.Get(HeaderLocalPath)
|
||||
remoteHash := resp.Header.Get(HeaderContentSHA256)
|
||||
if remotePath == "" || remoteHash == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
localHash, err := downloader.CalculateSHA(localPath)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if localHash != remoteHash {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return remotePath, true
|
||||
}
|
||||
|
||||
// progressReader wraps an io.Reader and logs upload progress periodically.
|
||||
type progressReader struct {
|
||||
reader io.Reader
|
||||
|
||||
@@ -2,6 +2,9 @@ package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,18 +13,26 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// Headers used by the HEAD probe / skip-if-exists protocol.
|
||||
const (
|
||||
HeaderContentSHA256 = "X-Content-SHA256"
|
||||
HeaderLocalPath = "X-Local-Path"
|
||||
HeaderFileSize = "X-File-Size"
|
||||
hashSidecarSuffix = ".sha256"
|
||||
)
|
||||
|
||||
// StartFileTransferServer starts a small HTTP server for file transfer in distributed mode.
|
||||
// It provides PUT/GET/POST endpoints for uploading, downloading, and allocating temp files,
|
||||
// as well as backend log REST and WebSocket endpoints when logStore is non-nil.
|
||||
@@ -72,6 +83,8 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir
|
||||
xlog.Debug("HTTP file transfer request", "method", r.Method, "key", key, "remote", r.RemoteAddr)
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodHead:
|
||||
handleHead(w, r, stagingDir, modelsDir, dataDir, key)
|
||||
case http.MethodPut:
|
||||
handleUpload(w, r, stagingDir, modelsDir, dataDir, key, maxUploadSize)
|
||||
case http.MethodGet:
|
||||
@@ -112,6 +125,47 @@ func StartFileTransferServerWithListener(lis net.Listener, stagingDir, modelsDir
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func handleHead(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string) {
|
||||
if key == "" {
|
||||
http.Error(w, "key is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
targetDir, relName := resolveKeyToDir(key, stagingDir, modelsDir, dataDir)
|
||||
filePath := filepath.Join(targetDir, relName)
|
||||
|
||||
if err := validatePathInDir(filePath, targetDir); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("stat error: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
http.Error(w, "is a directory", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set(HeaderFileSize, strconv.FormatInt(info.Size(), 10))
|
||||
w.Header().Set(HeaderLocalPath, filePath)
|
||||
|
||||
hashHex, err := computeAndCacheHash(filePath)
|
||||
if err != nil {
|
||||
xlog.Warn("Failed to compute hash for HEAD", "path", filePath, "error", err)
|
||||
} else {
|
||||
w.Header().Set(HeaderContentSHA256, hashHex)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string, maxUploadSize int64) {
|
||||
if key == "" {
|
||||
http.Error(w, "key is required", http.StatusBadRequest)
|
||||
@@ -146,15 +200,22 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
n, err := io.Copy(f, r.Body)
|
||||
hasher := sha256.New()
|
||||
n, err := io.Copy(f, io.TeeReader(r.Body, hasher))
|
||||
if err != nil {
|
||||
os.Remove(dstPath)
|
||||
os.Remove(dstPath + hashSidecarSuffix)
|
||||
xlog.Error("File upload failed", "key", key, "bytesReceived", n, "contentLength", r.ContentLength, "remote", r.RemoteAddr, "error", err)
|
||||
http.Error(w, fmt.Sprintf("writing file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Info("File upload complete", "key", key, "path", dstPath, "size", n)
|
||||
hashHex := hex.EncodeToString(hasher.Sum(nil))
|
||||
if err := os.WriteFile(dstPath+hashSidecarSuffix, []byte(hashHex), 0640); err != nil {
|
||||
xlog.Warn("Failed to write hash sidecar", "path", dstPath+hashSidecarSuffix, "error", err)
|
||||
}
|
||||
|
||||
xlog.Info("File upload complete", "key", key, "path", dstPath, "size", n, "sha256", hashHex)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(map[string]string{"local_path": dstPath}); err != nil {
|
||||
@@ -162,6 +223,37 @@ func handleUpload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir,
|
||||
}
|
||||
}
|
||||
|
||||
// computeAndCacheHash returns the SHA-256 hex digest for filePath.
|
||||
// It reads a cached sidecar when available and still fresh (sidecar mtime >=
|
||||
// file mtime), otherwise computes the hash and writes/updates the sidecar.
|
||||
func computeAndCacheHash(filePath string) (string, error) {
|
||||
sidecar := filePath + hashSidecarSuffix
|
||||
|
||||
fileStat, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if sidecarStat, err := os.Stat(sidecar); err == nil && !sidecarStat.ModTime().Before(fileStat.ModTime()) {
|
||||
if data, err := os.ReadFile(sidecar); err == nil {
|
||||
h := strings.TrimSpace(string(data))
|
||||
if len(h) == 64 { // valid hex-encoded SHA-256
|
||||
return h, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hashHex, err := downloader.CalculateSHA(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(sidecar, []byte(hashHex), 0640); err != nil {
|
||||
xlog.Warn("Failed to write hash sidecar", "path", sidecar, "error", err)
|
||||
}
|
||||
return hashHex, nil
|
||||
}
|
||||
|
||||
func handleDownload(w http.ResponseWriter, r *http.Request, stagingDir, modelsDir, dataDir, key string) {
|
||||
if key == "" {
|
||||
http.Error(w, "key is required", http.StatusBadRequest)
|
||||
|
||||
@@ -2,6 +2,9 @@ package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -28,6 +31,8 @@ var _ = Describe("FileTransferServer", func() {
|
||||
}
|
||||
key := strings.TrimPrefix(r.URL.Path, "/v1/files/")
|
||||
switch r.Method {
|
||||
case http.MethodHead:
|
||||
handleHead(w, r, stagingDir, modelsDir, dataDir, key)
|
||||
case http.MethodPut:
|
||||
handleUpload(w, r, stagingDir, modelsDir, dataDir, key, maxUploadSize)
|
||||
case http.MethodGet:
|
||||
@@ -196,4 +201,367 @@ var _ = Describe("FileTransferServer", func() {
|
||||
Expect(string(body)).To(Equal(content))
|
||||
})
|
||||
})
|
||||
|
||||
// --- HEAD handler tests ---
|
||||
|
||||
Describe("HEAD probe", func() {
|
||||
It("returns 404 for non-existent file", func() {
|
||||
ts, _, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/missing.bin", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Authorization", "Bearer tok")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
|
||||
})
|
||||
|
||||
It("returns 200 with hash, size, and path for an uploaded file", func() {
|
||||
ts, _, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := "hello distributed world"
|
||||
expectedHash := sha256Hex([]byte(content))
|
||||
|
||||
// Upload first
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/probe.txt", strings.NewReader(content))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer tok")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
Expect(putResp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
// HEAD should return hash + size + path
|
||||
headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/probe.txt", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headReq.Header.Set("Authorization", "Bearer tok")
|
||||
headResp, err := http.DefaultClient.Do(headReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headResp.Body.Close()
|
||||
Expect(headResp.StatusCode).To(Equal(http.StatusOK))
|
||||
Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(expectedHash))
|
||||
Expect(headResp.Header.Get(HeaderFileSize)).To(Equal("23"))
|
||||
Expect(headResp.Header.Get(HeaderLocalPath)).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("routes models/ prefixed keys to modelsDir", func() {
|
||||
ts, _, modelsDir, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := "model data"
|
||||
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/models/test/w.bin", strings.NewReader(content))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer tok")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
|
||||
headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/models/test/w.bin", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headReq.Header.Set("Authorization", "Bearer tok")
|
||||
headResp, err := http.DefaultClient.Do(headReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headResp.Body.Close()
|
||||
Expect(headResp.StatusCode).To(Equal(http.StatusOK))
|
||||
Expect(headResp.Header.Get(HeaderLocalPath)).To(HavePrefix(modelsDir))
|
||||
})
|
||||
|
||||
It("computes and caches hash for pre-existing file without sidecar", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
// Write file directly (no upload, no sidecar)
|
||||
content := []byte("pre-existing content")
|
||||
err := os.WriteFile(filepath.Join(stagingDir, "legacy.bin"), content, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
expectedHash := sha256Hex(content)
|
||||
|
||||
headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/legacy.bin", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headReq.Header.Set("Authorization", "Bearer tok")
|
||||
headResp, err := http.DefaultClient.Do(headReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headResp.Body.Close()
|
||||
Expect(headResp.StatusCode).To(Equal(http.StatusOK))
|
||||
Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(expectedHash))
|
||||
|
||||
// Sidecar should now be cached
|
||||
sidecar, err := os.ReadFile(filepath.Join(stagingDir, "legacy.bin.sha256"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(sidecar)).To(Equal(expectedHash))
|
||||
})
|
||||
|
||||
It("returns updated hash after re-upload with different content", func() {
|
||||
ts, _, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
// Upload v1
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/changing.bin", strings.NewReader("version1"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer tok")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
|
||||
hash1 := sha256Hex([]byte("version1"))
|
||||
|
||||
headReq, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/changing.bin", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headReq.Header.Set("Authorization", "Bearer tok")
|
||||
headResp, err := http.DefaultClient.Do(headReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headResp.Body.Close()
|
||||
Expect(headResp.Header.Get(HeaderContentSHA256)).To(Equal(hash1))
|
||||
|
||||
// Re-upload v2
|
||||
putReq2, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/changing.bin", strings.NewReader("version2"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq2.Header.Set("Authorization", "Bearer tok")
|
||||
putResp2, err := http.DefaultClient.Do(putReq2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp2.Body.Close()
|
||||
|
||||
hash2 := sha256Hex([]byte("version2"))
|
||||
Expect(hash2).ToNot(Equal(hash1))
|
||||
|
||||
headReq2, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/changing.bin", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headReq2.Header.Set("Authorization", "Bearer tok")
|
||||
headResp2, err := http.DefaultClient.Do(headReq2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
headResp2.Body.Close()
|
||||
Expect(headResp2.Header.Get(HeaderContentSHA256)).To(Equal(hash2))
|
||||
})
|
||||
|
||||
It("enforces bearer token auth on HEAD", func() {
|
||||
ts, _, _, _ := setupTestServer("secret", 0)
|
||||
|
||||
content := "authed"
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/auth-head.txt", strings.NewReader(content))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer secret")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
|
||||
// No token
|
||||
req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/auth-head.txt", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp.Body.Close()
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized))
|
||||
|
||||
// Wrong token
|
||||
req2, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/auth-head.txt", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req2.Header.Set("Authorization", "Bearer wrong")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp2.Body.Close()
|
||||
Expect(resp2.StatusCode).To(Equal(http.StatusUnauthorized))
|
||||
})
|
||||
|
||||
It("rejects path traversal on HEAD", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
// Place a file that a traversal path might resolve to
|
||||
target := filepath.Join(filepath.Dir(stagingDir), "escape.txt")
|
||||
err := os.WriteFile(target, []byte("secret"), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
DeferCleanup(func() { os.Remove(target) })
|
||||
|
||||
// Use a raw request to prevent Go's URL cleaning from collapsing ".."
|
||||
req, err := http.NewRequest(http.MethodHead, ts.URL+"/v1/files/../escape.txt", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req.Header.Set("Authorization", "Bearer tok")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resp.Body.Close()
|
||||
// Go's HTTP mux cleans the URL, so the request becomes /v1/files/escape.txt
|
||||
// which resolves inside stagingDir. The traversal is handled by validatePathInDir.
|
||||
// A direct "../" in the key is caught as 400, but Go cleans it first.
|
||||
// Either 400 (traversal blocked) or 404 (file not in staging) is acceptable.
|
||||
Expect(resp.StatusCode).To(SatisfyAny(
|
||||
Equal(http.StatusBadRequest),
|
||||
Equal(http.StatusNotFound),
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
// --- Upload sidecar tests ---
|
||||
|
||||
Describe("Upload hash sidecar", func() {
|
||||
It("creates a .sha256 sidecar alongside uploaded file", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := "sidecar test"
|
||||
expectedHash := sha256Hex([]byte(content))
|
||||
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/sidecar.txt", strings.NewReader(content))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer tok")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
Expect(putResp.StatusCode).To(Equal(http.StatusOK))
|
||||
|
||||
sidecar, err := os.ReadFile(filepath.Join(stagingDir, "sidecar.txt.sha256"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(sidecar)).To(Equal(expectedHash))
|
||||
})
|
||||
|
||||
It("creates sidecar under modelsDir for models/ prefix", func() {
|
||||
ts, _, modelsDir, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := "model sidecar test"
|
||||
expectedHash := sha256Hex([]byte(content))
|
||||
|
||||
putReq, err := http.NewRequest(http.MethodPut, ts.URL+"/v1/files/models/m1/w.bin", strings.NewReader(content))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putReq.Header.Set("Authorization", "Bearer tok")
|
||||
putResp, err := http.DefaultClient.Do(putReq)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
putResp.Body.Close()
|
||||
|
||||
sidecar, err := os.ReadFile(filepath.Join(modelsDir, "m1", "w.bin.sha256"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(sidecar)).To(Equal(expectedHash))
|
||||
})
|
||||
})
|
||||
|
||||
// --- EnsureRemote skip tests ---
|
||||
|
||||
Describe("EnsureRemote skip-if-exists", func() {
|
||||
It("skips upload when file exists with matching hash", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := []byte("already on worker")
|
||||
expectedHash := sha256Hex(content)
|
||||
|
||||
// Pre-place file and sidecar on the "worker"
|
||||
err := os.WriteFile(filepath.Join(stagingDir, "present.bin"), content, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(stagingDir, "present.bin.sha256"), []byte(expectedHash), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create matching local file
|
||||
localDir := GinkgoT().TempDir()
|
||||
localPath := filepath.Join(localDir, "present.bin")
|
||||
err = os.WriteFile(localPath, content, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
addr := strings.TrimPrefix(ts.URL, "http://")
|
||||
stager := NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
return addr, nil
|
||||
}, "tok")
|
||||
|
||||
remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "present.bin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(remotePath).To(Equal(filepath.Join(stagingDir, "present.bin")))
|
||||
})
|
||||
|
||||
It("uploads when file exists but hash differs", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
oldContent := []byte("old version")
|
||||
oldHash := sha256Hex(oldContent)
|
||||
|
||||
// Pre-place old file and sidecar
|
||||
err := os.WriteFile(filepath.Join(stagingDir, "changed.bin"), oldContent, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = os.WriteFile(filepath.Join(stagingDir, "changed.bin.sha256"), []byte(oldHash), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create local file with different content
|
||||
newContent := []byte("new version")
|
||||
localDir := GinkgoT().TempDir()
|
||||
localPath := filepath.Join(localDir, "changed.bin")
|
||||
err = os.WriteFile(localPath, newContent, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
addr := strings.TrimPrefix(ts.URL, "http://")
|
||||
stager := NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
return addr, nil
|
||||
}, "tok")
|
||||
|
||||
remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "changed.bin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(remotePath).ToNot(BeEmpty())
|
||||
|
||||
// Verify new content was uploaded
|
||||
uploaded, err := os.ReadFile(filepath.Join(stagingDir, "changed.bin"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(uploaded).To(Equal(newContent))
|
||||
|
||||
// Verify sidecar was updated
|
||||
sidecar, err := os.ReadFile(filepath.Join(stagingDir, "changed.bin.sha256"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(sidecar)).To(Equal(sha256Hex(newContent)))
|
||||
})
|
||||
|
||||
It("uploads when HEAD returns 404", func() {
|
||||
ts, stagingDir, _, _ := setupTestServer("tok", 0)
|
||||
|
||||
content := []byte("fresh upload")
|
||||
localDir := GinkgoT().TempDir()
|
||||
localPath := filepath.Join(localDir, "new.bin")
|
||||
err := os.WriteFile(localPath, content, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
addr := strings.TrimPrefix(ts.URL, "http://")
|
||||
stager := NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
return addr, nil
|
||||
}, "tok")
|
||||
|
||||
remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "new.bin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(remotePath).ToNot(BeEmpty())
|
||||
|
||||
uploaded, err := os.ReadFile(filepath.Join(stagingDir, "new.bin"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(uploaded).To(Equal(content))
|
||||
})
|
||||
|
||||
It("falls through to upload when HEAD returns 405 (old server)", func() {
|
||||
// Set up a server that does NOT support HEAD (simulates old server)
|
||||
stagingDir := GinkgoT().TempDir()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/v1/files/", func(w http.ResponseWriter, r *http.Request) {
|
||||
key := strings.TrimPrefix(r.URL.Path, "/v1/files/")
|
||||
switch r.Method {
|
||||
case http.MethodPut:
|
||||
handleUpload(w, r, stagingDir, "", "", key, 0)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
ts := httptest.NewServer(mux)
|
||||
DeferCleanup(ts.Close)
|
||||
|
||||
content := []byte("should still upload")
|
||||
localDir := GinkgoT().TempDir()
|
||||
localPath := filepath.Join(localDir, "compat.bin")
|
||||
err := os.WriteFile(localPath, content, 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
addr := strings.TrimPrefix(ts.URL, "http://")
|
||||
stager := NewHTTPFileStager(func(nodeID string) (string, error) {
|
||||
return addr, nil
|
||||
}, "")
|
||||
|
||||
remotePath, err := stager.EnsureRemote(context.Background(), "node-1", localPath, "compat.bin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(remotePath).ToNot(BeEmpty())
|
||||
|
||||
uploaded, err := os.ReadFile(filepath.Join(stagingDir, "compat.bin"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(uploaded).To(Equal(content))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
@@ -433,7 +433,7 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
// File exists, check SHA
|
||||
if sha != "" {
|
||||
// Verify SHA
|
||||
calculatedSHA, err := calculateSHA(filePath)
|
||||
calculatedSHA, err := CalculateSHA(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate SHA for file %q: %v", filePath, err)
|
||||
}
|
||||
@@ -609,7 +609,7 @@ func formatBytes(bytes int64) string {
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func calculateSHA(filePath string) (string, error) {
|
||||
func CalculateSHA(filePath string) (string, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
Reference in New Issue
Block a user