diff --git a/core/application/startup.go b/core/application/startup.go index 6438c7df3..fa5de5ede 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -25,6 +25,7 @@ import ( "github.com/mudler/LocalAI/core/services/storage" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/vram" @@ -71,6 +72,16 @@ func New(opts ...config.AppOption) (*Application, error) { if err != nil { return nil, fmt.Errorf("unable to create ModelPath: %q", err) } + + // Reap *.partial downloads abandoned by a previous run (killed mid-transfer + // by an OOM/restart, or stalled before cleanup could run). The 24h window + // is well beyond any legitimate in-flight download, so this never trims an + // active transfer; it just stops dead partials accumulating on the volume. + if removed, cErr := downloader.CleanupStalePartialFiles(options.SystemState.Model.ModelsPath, 24*time.Hour); cErr != nil { + xlog.Warn("Failed to reap stale partial downloads", "error", cErr) + } else if removed > 0 { + xlog.Info("Reaped stale partial downloads", "count", removed) + } if options.GeneratedContentDir != "" { err := os.MkdirAll(options.GeneratedContentDir, 0o750) if err != nil { diff --git a/core/services/galleryop/service.go b/core/services/galleryop/service.go index df0352e99..5b611d41e 100644 --- a/core/services/galleryop/service.go +++ b/core/services/galleryop/service.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services/distributed" "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" @@ -402,6 +403,16 @@ func (g *GalleryService) applyCancel(id string) { } } +// newUserCancellableContext returns a child context whose CancelFunc 
cancels +// with the downloader.ErrUserCancelled cause. This lets the download layer +// distinguish a deliberate user cancel (discard the half-downloaded .partial) +// from an incidental cancellation such as process shutdown (keep the .partial +// so the next run resumes via Range instead of restarting from zero). +func newUserCancellableContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancelCause := context.WithCancelCause(parent) + return ctx, func() { cancelCause(downloader.ErrUserCancelled) } +} + // storeCancellation stores a cancellation function for an operation func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) { g.Lock() @@ -444,7 +455,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.BackendGalleryChannel: // Create context if not provided if op.Context == nil { - op.Context, op.CancelFunc = context.WithCancel(c) + op.Context, op.CancelFunc = newUserCancellableContext(c) g.storeCancellation(op.ID, op.CancelFunc) } else if op.CancelFunc != nil { g.storeCancellation(op.ID, op.CancelFunc) @@ -472,7 +483,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.ModelGalleryChannel: // Create context if not provided if op.Context == nil { - op.Context, op.CancelFunc = context.WithCancel(c) + op.Context, op.CancelFunc = newUserCancellableContext(c) g.storeCancellation(op.ID, op.CancelFunc) } else if op.CancelFunc != nil { g.storeCancellation(op.ID, op.CancelFunc) diff --git a/pkg/downloader/cancel_test.go b/pkg/downloader/cancel_test.go new file mode 100644 index 000000000..76f8a2df5 --- /dev/null +++ b/pkg/downloader/cancel_test.go @@ -0,0 +1,148 @@ +package downloader_test + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("Download cancellation", func() { + var filePath string + + // streamingRangeServer serves data one small chunk at a time with a short + // pause between chunks, so a context cancellation can land mid-transfer. + // It honors a `bytes=N-` Range request so a second attempt can resume. + streamingRangeServer := func(data []byte) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + start := 0 + if rh := r.Header.Get("Range"); rh != "" { + _, _ = fmt.Sscanf(strings.TrimPrefix(rh, "bytes="), "%d-", &start) + } + w.Header().Set("Content-Length", strconv.Itoa(len(data)-start)) + if start > 0 { + w.WriteHeader(http.StatusPartialContent) + } else { + w.WriteHeader(http.StatusOK) + } + f, _ := w.(http.Flusher) + for i := start; i < len(data); i += 256 { + end := i + 256 + if end > len(data) { + end = len(data) + } + if _, err := w.Write(data[i:end]); err != nil { + return + } + if f != nil { + f.Flush() + } + time.Sleep(20 * time.Millisecond) + } + })) + } + + BeforeEach(func() { + dir, err := os.Getwd() + Expect(err).ToNot(HaveOccurred()) + filePath = dir + "/cancel_model" + }) + + AfterEach(func() { + _ = os.Remove(filePath) + _ = os.Remove(filePath + ".partial") + }) + + It("keeps the .partial file when the context is cancelled so the download can resume", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := streamingRangeServer(data) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel() + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + + info, 
statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), + "a cancelled download must leave its .partial behind so the retry resumes instead of restarting from zero") + Expect(info.Size()).To(BeNumerically(">", 0)) + Expect(info.Size()).To(BeNumerically("<", int64(len(data)))) + }) + + It("discards the .partial when the cancellation cause is ErrUserCancelled", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := streamingRangeServer(data) + defer server.Close() + + // A deliberate user abort: cancel WITH the ErrUserCancelled cause. The + // half-finished download should not linger on disk. + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel(ErrUserCancelled) + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + + Expect(filePath + ".partial").ToNot(BeAnExistingFile(), + "a deliberate user cancel must not leave a dangling .partial behind") + }) + + It("resumes from the preserved .partial after a cancellation and completes", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + sum := sha256.Sum256(data) + sha := fmt.Sprintf("%x", sum) + server := streamingRangeServer(data) + defer server.Close() + + // First attempt: cancel mid-stream. 
+ ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel() + }() + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + partialInfo, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred()) + resumedFrom := partialInfo.Size() + Expect(resumedFrom).To(BeNumerically(">", 0)) + + // Second attempt: fresh context, must resume and finish with a valid SHA. + err = URI(server.URL).DownloadFileWithContext(context.Background(), filePath, sha, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + final, rerr := os.ReadFile(filePath) + Expect(rerr).ToNot(HaveOccurred()) + Expect(final).To(Equal(data)) + }) +}) diff --git a/pkg/downloader/partial.go b/pkg/downloader/partial.go new file mode 100644 index 000000000..f816bb09f --- /dev/null +++ b/pkg/downloader/partial.go @@ -0,0 +1,69 @@ +package downloader + +import ( + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/mudler/xlog" +) + +// PartialFileSuffix marks an in-progress download. The success path renames the +// partial to its final name, so any leftover with this suffix is an unfinished +// transfer. +const PartialFileSuffix = ".partial" + +// CleanupStalePartialFiles removes *.partial files under root whose last +// modification is older than olderThan, returning the number removed. These are +// abandoned downloads left by a process killed mid-transfer (OOM, restart) or +// by a stall whose cleanup never ran; without reaping they accumulate and can +// fill the models volume. A still-in-progress download touches its .partial on +// every write, so a generous olderThan never trims an active transfer. +// +// A missing root is not an error (nothing to clean). Unreadable entries are +// skipped so one bad file does not abort the whole sweep. 
+func CleanupStalePartialFiles(root string, olderThan time.Duration) (int, error) { + if _, err := os.Stat(root); err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, err + } + + cutoff := time.Now().Add(-olderThan) + + // Collect candidates during the walk and delete them afterwards rather than + // mutating the tree from inside the WalkDir callback (avoids the symlink + // TOCTOU class flagged by gosec G122, and never removes an entry mid-walk). + var stale []string + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil // skip unreadable subtree, keep going + } + if d.IsDir() || !strings.HasSuffix(d.Name(), PartialFileSuffix) { + return nil + } + info, err := d.Info() + if err != nil || info.ModTime().After(cutoff) { + return nil + } + stale = append(stale, path) + return nil + }) + if err != nil { + return 0, err + } + + removed := 0 + for _, path := range stale { + if err := os.Remove(path); err != nil { + xlog.Warn("failed to remove stale partial download", "file", path, "error", err) + continue + } + removed++ + xlog.Info("removed stale partial download", "file", path) + } + return removed, nil +} diff --git a/pkg/downloader/partial_test.go b/pkg/downloader/partial_test.go new file mode 100644 index 000000000..ceec8417f --- /dev/null +++ b/pkg/downloader/partial_test.go @@ -0,0 +1,53 @@ +package downloader_test + +import ( + "os" + "path/filepath" + "time" + + . "github.com/mudler/LocalAI/pkg/downloader" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +var _ = Describe("CleanupStalePartialFiles", func() { + var root string + + BeforeEach(func() { + var err error + root, err = os.MkdirTemp("", "partials") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + _ = os.RemoveAll(root) + }) + + It("removes stale .partial files (recursively) while keeping fresh ones and completed files", func() { + nested := filepath.Join(root, "llama-cpp", "models", "foo") + Expect(os.MkdirAll(nested, 0755)).To(Succeed()) + + stale := filepath.Join(nested, "model.gguf.partial") + fresh := filepath.Join(root, "fresh.gguf.partial") + completed := filepath.Join(root, "done.gguf") + for _, f := range []string{stale, fresh, completed} { + Expect(os.WriteFile(f, []byte("data"), 0644)).To(Succeed()) + } + old := time.Now().Add(-2 * time.Hour) + Expect(os.Chtimes(stale, old, old)).To(Succeed()) + + removed, err := CleanupStalePartialFiles(root, time.Hour) + Expect(err).ToNot(HaveOccurred()) + Expect(removed).To(Equal(1)) + + Expect(stale).ToNot(BeAnExistingFile()) + Expect(fresh).To(BeAnExistingFile()) + Expect(completed).To(BeAnExistingFile()) + }) + + It("returns no error when the root directory does not exist", func() { + removed, err := CleanupStalePartialFiles(filepath.Join(root, "does-not-exist"), time.Hour) + Expect(err).ToNot(HaveOccurred()) + Expect(removed).To(Equal(0)) + }) +}) diff --git a/pkg/downloader/stall.go b/pkg/downloader/stall.go new file mode 100644 index 000000000..697ad25d9 --- /dev/null +++ b/pkg/downloader/stall.go @@ -0,0 +1,77 @@ +package downloader + +import ( + "fmt" + "io" + "sync" + "time" +) + +// DownloadStallTimeout bounds how long an in-flight download may receive no +// data before it is aborted. A silently-dropped TCP connection (no FIN/RST) +// would otherwise block the body read forever, freezing an install at N bytes +// until an external reaper kills it. Overridable (tests set it small); a value +// <= 0 disables the guard. 
var DownloadStallTimeout = 60 * time.Second // a value <= 0 disables the stall guard

// Compile-time check that the watchdog wrapper is a drop-in io.ReadCloser.
var _ io.ReadCloser = (*idleTimeoutReader)(nil)

// idleTimeoutReader wraps a streaming ReadCloser and aborts reads that make no
// progress within timeout. A standard io.Copy blocks indefinitely on a Read
// against a dead-but-unclosed socket; nothing in the copy loop can interrupt a
// blocked syscall. The watchdog timer closes the underlying reader on expiry,
// which unblocks the in-flight Read with an error. Each read that returns data
// resets the idle clock, so a slow-but-steady transfer never trips the guard.
type idleTimeoutReader struct {
	rc      io.ReadCloser
	timeout time.Duration

	mu    sync.Mutex  // guards fired and done, and serialises timer re-arming
	timer *time.Timer // nil when the guard is disabled (timeout <= 0)
	fired bool        // set by the watchdog just before it closes rc
	done  bool        // set by Close; tells a late watchdog to do nothing
}

// newIdleTimeoutReader wraps rc with a stall watchdog that starts counting
// immediately. A timeout <= 0 disables the watchdog: the returned reader then
// simply forwards Read and Close to rc. Without that guard time.AfterFunc would
// fire at once and close the stream before any data could be read.
func newIdleTimeoutReader(rc io.ReadCloser, timeout time.Duration) *idleTimeoutReader {
	r := &idleTimeoutReader{rc: rc, timeout: timeout}
	if timeout > 0 {
		r.timer = time.AfterFunc(timeout, r.onStall)
	}
	return r
}

// onStall fires when no data has arrived within the timeout. Closing the
// underlying reader is what unblocks a Read parked in the kernel.
func (r *idleTimeoutReader) onStall() {
	r.mu.Lock()
	if r.done {
		// Close already ran; the stream is being torn down normally.
		r.mu.Unlock()
		return
	}
	r.fired = true
	r.mu.Unlock()
	// Close outside the lock: it may block, and Read takes the same lock.
	_ = r.rc.Close()
}

// Read forwards to the wrapped reader. Every call that yields data restarts the
// idle clock. If the watchdog has already fired, any error from the wrapped
// reader is replaced by a "download stalled" error.
func (r *idleTimeoutReader) Read(p []byte) (int, error) {
	n, err := r.rc.Read(p)

	r.mu.Lock()
	fired := r.fired
	// Re-arm only while the reader is live: after Close, or once the watchdog
	// has fired, a fresh timer would only linger and fire to no effect.
	if n > 0 && r.timer != nil && !fired && !r.done {
		r.timer.Reset(r.timeout)
	}
	r.mu.Unlock()

	if err != nil && fired {
		// Translate the "use of closed connection" the watchdog induced
		// into an actionable stall error. This is not context.Canceled,
		// so the caller keeps the .partial file for a later resume.
		return n, fmt.Errorf("download stalled: no data received for %s", r.timeout)
	}
	return n, err
}

// Close stops the watchdog and closes the wrapped reader, returning the wrapped
// reader's Close error. Marking the reader done first makes a watchdog that
// fires concurrently a no-op.
func (r *idleTimeoutReader) Close() error {
	r.mu.Lock()
	r.done = true
	r.mu.Unlock()
	if r.timer != nil {
		r.timer.Stop()
	}
	return r.rc.Close()
}
+ release := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 4096)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + <-release // hang: no more data, never close + })) + defer server.Close() + defer close(release) + + DownloadStallTimeout = 300 * time.Millisecond + + done := make(chan error, 1) + go func() { + done <- URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + }() + + var err error + Eventually(done, "5s").Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("stall")) + }) + + It("preserves the .partial file when a download stalls so it can resume", func() { + release := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(make([]byte, 4096)) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + <-release + })) + defer server.Close() + defer close(release) + + DownloadStallTimeout = 300 * time.Millisecond + + done := make(chan error, 1) + go func() { + done <- URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + }() + Eventually(done, "5s").Should(Receive(HaveOccurred())) + + info, statErr := os.Stat(filePath + ".partial") + Expect(statErr).ToNot(HaveOccurred(), "the .partial must survive a stall so the next attempt can resume") + Expect(info.Size()).To(BeNumerically(">", 0)) + }) + + It("does not abort a slow-but-steady download", func() { + // One byte every 100ms keeps the idle clock from ever 
expiring even + // though the total transfer outlasts the stall timeout. + payload := make([]byte, 12) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + f, _ := w.(http.Flusher) + for i := range payload { + _, _ = w.Write(payload[i : i+1]) + if f != nil { + f.Flush() + } + time.Sleep(100 * time.Millisecond) + } + })) + defer server.Close() + + DownloadStallTimeout = 300 * time.Millisecond + + err := URI(server.URL).DownloadFileWithContext( + context.Background(), filePath, "", 1, 1, + func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + }) +}) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 4be1b9081..41bdbe672 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -330,6 +330,18 @@ func (s URI) ResolveURL() string { return string(s) } +// ErrUserCancelled distinguishes a deliberate user abort from an incidental +// context cancellation (process shutdown, pod restart). Pass it as the cause +// when cancelling the download context: +// +// ctx, cancel := context.WithCancelCause(parent) +// cancel(downloader.ErrUserCancelled) // discards the .partial +// +// On a deliberate cancel the downloader removes the .partial (the user does not +// want a half-download lingering). On a plain cancellation it keeps the .partial +// so the next run resumes via Range instead of restarting from zero. 
+var ErrUserCancelled = errors.New("download cancelled by user") + func removePartialFile(tmpFilePath string) error { xlog.Debug("Removing temporary file", "file", tmpFilePath) if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) { @@ -594,11 +606,17 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // Start the request resp, err := downloadClient.Do(req) if err != nil { - // Check if error is due to context cancellation - if errors.Is(err, context.Canceled) { - // Clean up partial file on cancellation - removePartialFile(tmpFilePath) - return err + // Detect cancellation via the context, not the returned error: a + // request cancelled *with a cause* surfaces the cause error (not + // context.Canceled) from the HTTP client. Keep the .partial for + // resume on an incidental cancel (shutdown, restart) — large GGUFs + // take long enough that deleting progress means they never finish — + // but discard it on a deliberate user abort (ErrUserCancelled). + if ctx.Err() != nil { + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } + return ctx.Err() } return fmt.Errorf("failed to download file %q: %v", filePath, err) } @@ -608,6 +626,13 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) } source = resp.Body + // Guard against a silently-stalled stream: a dropped TCP connection + // that never sends FIN/RST would otherwise block the body Read (and + // thus the whole install) forever. The watchdog aborts after a window + // of zero progress; the .partial is kept for a later resume. 
+ if DownloadStallTimeout > 0 { + source = newIdleTimeoutReader(resp.Body, DownloadStallTimeout) + } contentLength = resp.ContentLength } defer source.Close() @@ -640,19 +665,27 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string _, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source) if err != nil { - // Check if error is due to context cancellation - if errors.Is(err, context.Canceled) { - // Clean up partial file on cancellation - removePartialFile(tmpFilePath) - return err + // Detect cancellation via the context (a cause-cancelled read surfaces + // the cause, not context.Canceled). Keep the .partial for resume, + // except on a deliberate user abort (ErrUserCancelled), which discards + // it. A stall-guard abort leaves ctx uncancelled, so it falls through + // to the error path below and likewise preserves the partial. + if ctx.Err() != nil { + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } + return ctx.Err() } return fmt.Errorf("failed to write file %q: %v", filePath, err) } - // Check for cancellation before finalizing + // Check for cancellation before finalizing. Keep the .partial for resume + // unless the user deliberately aborted. select { case <-ctx.Done(): - removePartialFile(tmpFilePath) + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } return ctx.Err() default: }