diff --git a/core/services/galleryop/service.go b/core/services/galleryop/service.go index df0352e99..5b611d41e 100644 --- a/core/services/galleryop/service.go +++ b/core/services/galleryop/service.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services/distributed" "github.com/mudler/LocalAI/core/services/messaging" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" @@ -402,6 +403,16 @@ func (g *GalleryService) applyCancel(id string) { } } +// newUserCancellableContext returns a child of parent whose CancelFunc cancels +// with the downloader.ErrUserCancelled cause. This lets the download layer +// distinguish a deliberate user cancel (discard the half-downloaded .partial) +// from an incidental cancellation such as process shutdown (keep the .partial +// so the next run resumes via Range instead of restarting from zero). +func newUserCancellableContext(parent context.Context) (context.Context, context.CancelFunc) { + ctx, cancelCause := context.WithCancelCause(parent) + return ctx, func() { cancelCause(downloader.ErrUserCancelled) } +} + // storeCancellation stores a cancellation function for an operation func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) { g.Lock() @@ -444,7 +455,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.BackendGalleryChannel: // Create context if not provided if op.Context == nil { - op.Context, op.CancelFunc = context.WithCancel(c) + op.Context, op.CancelFunc = newUserCancellableContext(c) g.storeCancellation(op.ID, op.CancelFunc) } else if op.CancelFunc != nil { g.storeCancellation(op.ID, op.CancelFunc) @@ -472,7 +483,7 @@ func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, case op := <-g.ModelGalleryChannel: // Create context if not provided if op.Context == nil { - 
op.Context, op.CancelFunc = context.WithCancel(c) + op.Context, op.CancelFunc = newUserCancellableContext(c) g.storeCancellation(op.ID, op.CancelFunc) } else if op.CancelFunc != nil { g.storeCancellation(op.ID, op.CancelFunc) diff --git a/pkg/downloader/cancel_test.go b/pkg/downloader/cancel_test.go index c9025ffa3..76f8a2df5 100644 --- a/pkg/downloader/cancel_test.go +++ b/pkg/downloader/cancel_test.go @@ -93,6 +93,29 @@ var _ = Describe("Download cancellation", func() { Expect(info.Size()).To(BeNumerically("<", int64(len(data)))) }) + It("discards the .partial when the cancellation cause is ErrUserCancelled", func() { + data := make([]byte, 8192) + _, err := rand.Read(data) + Expect(err).ToNot(HaveOccurred()) + server := streamingRangeServer(data) + defer server.Close() + + // A deliberate user abort: cancel WITH the ErrUserCancelled cause. The + // half-finished download should not linger on disk. + ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + time.Sleep(150 * time.Millisecond) + cancel(ErrUserCancelled) + }() + + err = URI(server.URL).DownloadFileWithContext(ctx, filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + + Expect(filePath + ".partial").ToNot(BeAnExistingFile(), + "a deliberate user cancel must not leave a dangling .partial behind") + }) + It("resumes from the preserved .partial after a cancellation and completes", func() { data := make([]byte, 8192) _, err := rand.Read(data) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index e5c06c61d..41bdbe672 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -330,6 +330,18 @@ func (s URI) ResolveURL() string { return string(s) } +// ErrUserCancelled distinguishes a deliberate user abort from an incidental +// context cancellation (process shutdown, pod restart). 
Pass it as the cause +// when cancelling the download context: +// +// ctx, cancel := context.WithCancelCause(parent) +// cancel(downloader.ErrUserCancelled) // discards the .partial +// +// On a deliberate cancel the downloader removes the .partial (the user does not +// want a half-download lingering). On a plain cancellation it keeps the .partial +// so the next run resumes via Range instead of restarting from zero. +var ErrUserCancelled = errors.New("download cancelled by user") + func removePartialFile(tmpFilePath string) error { xlog.Debug("Removing temporary file", "file", tmpFilePath) if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) { @@ -594,13 +606,17 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // Start the request resp, err := downloadClient.Do(req) if err != nil { - // On cancellation keep the .partial file: the next attempt resumes - // via a Range request instead of restarting from zero. Frontend - // restarts (deploys, OOM) cancel in-flight downloads, and large - // GGUFs take long enough that deleting progress means they never - // finish. - if errors.Is(err, context.Canceled) { - return err + // Detect cancellation via the context, not the returned error: a + // request cancelled *with a cause* surfaces the cause error (not + // context.Canceled) from the HTTP client. Keep the .partial for + // resume on an incidental cancel (shutdown, restart) — large GGUFs + // take long enough that deleting progress means they never finish — + // but discard it on a deliberate user abort (ErrUserCancelled). 
+ if ctx.Err() != nil { + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } + return ctx.Err() } return fmt.Errorf("failed to download file %q: %v", filePath, err) } @@ -649,18 +665,27 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string _, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source) if err != nil { - // Keep the .partial on cancellation so the next attempt resumes. A - // stall-guard abort is a plain error (not context.Canceled) and also - // falls through here, likewise preserving the partial for resume. - if errors.Is(err, context.Canceled) { - return err + // Detect cancellation via the context (a cause-cancelled read surfaces + // the cause, not context.Canceled). Keep the .partial for resume, + // except on a deliberate user abort (ErrUserCancelled), which discards + // it. A stall-guard abort leaves ctx uncancelled, so it falls through + // to the error path below and likewise preserves the partial. + if ctx.Err() != nil { + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } + return ctx.Err() } return fmt.Errorf("failed to write file %q: %v", filePath, err) } - // Check for cancellation before finalizing. Keep the .partial for resume. + // Check for cancellation before finalizing. Keep the .partial for resume + // unless the user deliberately aborted (cause is ErrUserCancelled). select { case <-ctx.Done(): + if errors.Is(context.Cause(ctx), ErrUserCancelled) { + _ = removePartialFile(tmpFilePath) + } return ctx.Err() default: }