diff --git a/core/application/startup.go b/core/application/startup.go index 6438c7df3..fa5de5ede 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -25,6 +25,7 @@ import ( "github.com/mudler/LocalAI/core/services/storage" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/vram" @@ -71,6 +72,16 @@ func New(opts ...config.AppOption) (*Application, error) { if err != nil { return nil, fmt.Errorf("unable to create ModelPath: %q", err) } + + // Reap *.partial downloads abandoned by a previous run (killed mid-transfer + // by an OOM/restart, or stalled before cleanup could run). The 24h window + // is well beyond any legitimate in-flight download, so this never trims an + // active transfer; it just stops dead partials accumulating on the volume. + if removed, cErr := downloader.CleanupStalePartialFiles(options.SystemState.Model.ModelsPath, 24*time.Hour); cErr != nil { + xlog.Warn("Failed to reap stale partial downloads", "error", cErr) + } else if removed > 0 { + xlog.Info("Reaped stale partial downloads", "count", removed) + } if options.GeneratedContentDir != "" { err := os.MkdirAll(options.GeneratedContentDir, 0o750) if err != nil { diff --git a/pkg/downloader/cancel_test.go b/pkg/downloader/cancel_test.go new file mode 100644 index 000000000..c9025ffa3 --- /dev/null +++ b/pkg/downloader/cancel_test.go @@ -0,0 +1,125 @@ +package downloader_test + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Download cancellation", func() { + var filePath string + + // streamingRangeServer serves data one small chunk at a time with a short + // pause between chunks, so a context cancellation can land mid-transfer. + // It honors a `bytes=N-` Range request so a second attempt can resume. + streamingRangeServer := func(data []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + start := 0 + if rh := r.Header.Get("Range"); rh != "" { + _, _ = fmt.Sscanf(strings.TrimPrefix(rh, "bytes="), "%d-", &start) + } + w.Header().Set("Content-Length", strconv.Itoa(len(data)-start)) + if start > 0 { + w.WriteHeader(http.StatusPartialContent) + } else { + w.WriteHeader(http.StatusOK) + } + f, _ := w.(http.Flusher) + for i := start; i < len(data); i += 256 { + end := i + 256 + if end > len(data) { + end = len(data) + } + if _, err := w.Write(data[i:end]); err != nil { + return + } + if f != nil { + f.Flush() + } + time.Sleep(20 * time.Millisecond) + } + })) + } + + BeforeEach(func() { + dir, err := os.Getwd() + Expect(err).ToNot(HaveOccurred()) + filePath = dir + "/cancel_model" + }) + + AfterEach(func() { + _ = os.Remove(filePath) + _ = os.Remove(filePath + ".partial") + }) + + It("keeps the .partial file when the context is cancelled so the download can resume", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := streamingRangeServer(data) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel() + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + + info, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), + "a cancelled download must leave its .partial behind so the retry resumes instead of restarting from zero") + Expect(info.Size()).To(BeNumerically(">", 0)) + Expect(info.Size()).To(BeNumerically("<", int64(len(data)))) + }) + + It("resumes from the preserved .partial after a cancellation and completes", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + sum := sha256.Sum256(data) + sha := fmt.Sprintf("%x", sum) + server := streamingRangeServer(data) + defer server.Close() + + // First attempt: cancel mid-stream. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel() + }() + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + partialInfo, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred()) + resumedFrom := partialInfo.Size() + Expect(resumedFrom).To(BeNumerically(">", 0)) + + // Second attempt: fresh context, must resume and finish with a valid SHA. + err = URI(server.URL).DownloadFileWithContext(context.Background(), filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + final, rerr := os.ReadFile(filePath) + Expect(rerr).ToNot(HaveOccurred()) + Expect(final).To(Equal(data)) + }) +}) diff --git a/pkg/downloader/partial.go b/pkg/downloader/partial.go new file mode 100644 index 000000000..7ad52f171 --- /dev/null +++ b/pkg/downloader/partial.go @@ -0,0 +1,57 @@ +package downloader + +import ( + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/mudler/xlog" +) + +// PartialFileSuffix marks an in-progress download. The success path renames the +// partial to its final name, so any leftover with this suffix is an unfinished +// transfer. +const PartialFileSuffix = ".partial" + +// CleanupStalePartialFiles removes *.partial files under root whose last +// modification is older than olderThan, returning the number removed. These are +// abandoned downloads left by a process killed mid-transfer (OOM, restart) or +// by a stall whose cleanup never ran; without reaping they accumulate and can +// fill the models volume. A still-in-progress download touches its .partial on +// every write, so a generous olderThan never trims an active transfer. +// +// A missing root is not an error (nothing to clean). Unreadable entries are +// skipped so one bad file does not abort the whole sweep. +func CleanupStalePartialFiles(root string, olderThan time.Duration) (int, error) { + if _, err := os.Stat(root); err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, err + } + + cutoff := time.Now().Add(-olderThan) + removed := 0 + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil // skip unreadable subtree, keep going + } + if d.IsDir() || !strings.HasSuffix(d.Name(), PartialFileSuffix) { + return nil + } + info, err := d.Info() + if err != nil || info.ModTime().After(cutoff) { + return nil + } + if err := os.Remove(path); err != nil { + xlog.Warn("failed to remove stale partial download", "file", path, "error", err) + return nil + } + removed++ + xlog.Info("removed stale partial download", "file", path) + return nil + }) + return removed, err +} diff --git a/pkg/downloader/partial_test.go b/pkg/downloader/partial_test.go new file mode 100644 index 000000000..ceec8417f --- /dev/null +++ b/pkg/downloader/partial_test.go @@ -0,0 +1,53 @@ +package downloader_test + +import ( + "os" + "path/filepath" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("CleanupStalePartialFiles", func() { + var root string + + BeforeEach(func() { + var err error + root, err = os.MkdirTemp("", "partials") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + _ = os.RemoveAll(root) + }) + + It("removes stale .partial files (recursively) while keeping fresh ones and completed files", func() { + nested := filepath.Join(root, "llama-cpp", "models", "foo") + Expect(os.MkdirAll(nested, 0755)).To(Succeed()) + + stale := filepath.Join(nested, "model.gguf.partial") + fresh := filepath.Join(root, "fresh.gguf.partial") + completed := filepath.Join(root, "done.gguf") + for _, f := range []string{stale, fresh, completed} { + Expect(os.WriteFile(f, []byte("data"), 0644)).To(Succeed()) + } + old := time.Now().Add(-2 * time.Hour) + Expect(os.Chtimes(stale, old, old)).To(Succeed()) + + removed, err := CleanupStalePartialFiles(root, time.Hour) + Expect(err).ToNot(HaveOccurred()) + Expect(removed).To(Equal(1)) + + Expect(stale).ToNot(BeAnExistingFile()) + Expect(fresh).To(BeAnExistingFile()) + Expect(completed).To(BeAnExistingFile()) + }) + + It("returns no error when the root directory does not exist", func() { + removed, err := CleanupStalePartialFiles(filepath.Join(root, "does-not-exist"), time.Hour) + Expect(err).ToNot(HaveOccurred()) + Expect(removed).To(Equal(0)) + }) +}) diff --git a/pkg/downloader/stall.go b/pkg/downloader/stall.go new file mode 100644 index 000000000..697ad25d9 --- /dev/null +++ b/pkg/downloader/stall.go @@ -0,0 +1,77 @@ +package downloader + +import ( + "fmt" + "io" + "sync" + "time" +) + +// DownloadStallTimeout bounds how long an in-flight download may receive no +// data before it is aborted. A silently-dropped TCP connection (no FIN/RST) +// would otherwise block the body read forever, freezing an install at N bytes +// until an external reaper kills it. Overridable (tests set it small); a value +// <= 0 disables the guard. +var DownloadStallTimeout = 60 * time.Second + +// idleTimeoutReader wraps a streaming ReadCloser and aborts reads that make no +// progress within timeout. A standard io.Copy blocks indefinitely on a Read +// against a dead-but-unclosed socket; nothing in the copy loop can interrupt a +// blocked syscall. The watchdog timer closes the underlying reader on expiry, +// which unblocks the in-flight Read with an error. Each read that returns data +// resets the idle clock, so a slow-but-steady transfer never trips the guard. +type idleTimeoutReader struct { + rc io.ReadCloser + timeout time.Duration + + mu sync.Mutex + timer *time.Timer + fired bool + done bool +} + +func newIdleTimeoutReader(rc io.ReadCloser, timeout time.Duration) *idleTimeoutReader { + r := &idleTimeoutReader{rc: rc, timeout: timeout} + r.timer = time.AfterFunc(timeout, r.onStall) + return r +} + +// onStall fires when no data has arrived within the timeout. Closing the +// underlying reader is what unblocks a Read parked in the kernel. +func (r *idleTimeoutReader) onStall() { + r.mu.Lock() + if r.done { + r.mu.Unlock() + return + } + r.fired = true + r.mu.Unlock() + _ = r.rc.Close() +} + +func (r *idleTimeoutReader) Read(p []byte) (int, error) { + n, err := r.rc.Read(p) + if n > 0 { + r.timer.Reset(r.timeout) + } + if err != nil { + r.mu.Lock() + fired := r.fired + r.mu.Unlock() + if fired { + // Translate the "use of closed connection" the watchdog induced + // into an actionable stall error. This is not context.Canceled, + // so the caller keeps the .partial file for a later resume. + return n, fmt.Errorf("download stalled: no data received for %s", r.timeout) + } + } + return n, err +} + +func (r *idleTimeoutReader) Close() error { + r.mu.Lock() + r.done = true + r.mu.Unlock() + r.timer.Stop() + return r.rc.Close() +} diff --git a/pkg/downloader/stall_test.go b/pkg/downloader/stall_test.go new file mode 100644 index 000000000..8e6a003c6 --- /dev/null +++ b/pkg/downloader/stall_test.go @@ -0,0 +1,131 @@ +package downloader_test + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Download stall timeout", func() { + var filePath string + var savedTimeout time.Duration + + BeforeEach(func() { + dir, err := os.Getwd() + Expect(err).ToNot(HaveOccurred()) + filePath = dir + "/stall_model" + savedTimeout = DownloadStallTimeout + }) + + AfterEach(func() { + DownloadStallTimeout = savedTimeout + _ = os.Remove(filePath) + _ = os.Remove(filePath + ".partial") + }) + + It("aborts a download that stalls mid-stream instead of hanging forever", func() { + // Server sends a chunk, flushes, then blocks forever without closing + // the connection — a silently-dropped TCP stream. Without a stall + // guard the body Read blocks indefinitely and DownloadFile never + // returns. + release := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 4096)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + <-release // hang: no more data, never close + })) + defer server.Close() + defer close(release) + + DownloadStallTimeout = 300 * time.Millisecond + + done := make(chan error, 1) + go func() { + done <- URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + }() + + var err error + Eventually(done, "5s").Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("stall")) + }) + + It("preserves the .partial file when a download stalls so it can resume", func() { + release := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 4096)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + <-release + })) + defer server.Close() + defer close(release) + + DownloadStallTimeout = 300 * time.Millisecond + + done := make(chan error, 1) + go func() { + done <- URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + }() + Eventually(done, "5s").Should(Receive(HaveOccurred())) + + info, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), "the .partial must survive a stall so the next attempt can resume") + Expect(info.Size()).To(BeNumerically(">", 0)) + }) + + It("does not abort a slow-but-steady download", func() { + // One byte every 100ms keeps the idle clock from ever expiring even + // though the total transfer outlasts the stall timeout. + payload := make([]byte, 12) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + f, _ := w.(http.Flusher) + for i := range payload { + _, _ = w.Write(payload[i : i+1]) + if f != nil { + f.Flush() + } + time.Sleep(100 * time.Millisecond) + } + })) + defer server.Close() + + DownloadStallTimeout = 300 * time.Millisecond + + err := URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + }) +}) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 4be1b9081..e5c06c61d 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -594,10 +594,12 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // Start the request resp, err := downloadClient.Do(req) if err != nil { - // Check if error is due to context cancellation + // On cancellation keep the .partial file: the next attempt resumes + // via a Range request instead of restarting from zero. Frontend + // restarts (deploys, OOM) cancel in-flight downloads, and large + // GGUFs take long enough that deleting progress means they never + // finish. if errors.Is(err, context.Canceled) { - // Clean up partial file on cancellation - removePartialFile(tmpFilePath) return err } return fmt.Errorf("failed to download file %q: %v", filePath, err) @@ -608,6 +610,13 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) } source = resp.Body + // Guard against a silently-stalled stream: a dropped TCP connection + // that never sends FIN/RST would otherwise block the body Read (and + // thus the whole install) forever. The watchdog aborts after a window + // of zero progress; the .partial is kept for a later resume. + if DownloadStallTimeout > 0 { + source = newIdleTimeoutReader(resp.Body, DownloadStallTimeout) + } contentLength = resp.ContentLength } defer source.Close() @@ -640,19 +649,18 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string _, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source) if err != nil { - // Check if error is due to context cancellation + // Keep the .partial on cancellation so the next attempt resumes. A + // stall-guard abort is a plain error (not context.Canceled) and also + // falls through here, likewise preserving the partial for resume. if errors.Is(err, context.Canceled) { - // Clean up partial file on cancellation - removePartialFile(tmpFilePath) return err } return fmt.Errorf("failed to write file %q: %v", filePath, err) } - // Check for cancellation before finalizing + // Check for cancellation before finalizing. Keep the .partial for resume. select { case <-ctx.Done(): - removePartialFile(tmpFilePath) return ctx.Err() default: }