Compare commits

...

3 Commits

Author SHA1 Message Date
Eva H
9bf41969f0 app: fix first update check delayed by 1 hour (#14427) 2026-02-25 18:29:55 -05:00
Jesse Gross
0f23b7bff5 mlxrunner: Cancel in-flight requests when the client disconnects
Currently, a canceled request can result in computation continuing
in the background to completion. It can also trigger a deadlock
when there is nobody to read the output tokens and the pipeline
cannot continue to the next request.
2026-02-25 14:00:42 -08:00
Jesse Gross
4e57d2094e mlxrunner: Simplify pipeline memory and cache management
Particularly in error cases, it can be difficult to ensure that
all pinned memory is unpinned, MLX buffers are released and cache
state is consistent. This encapsulates those pieces and sets up
proper deferrals so that this happens automatically on exit.
2026-02-25 14:00:42 -08:00
7 changed files with 160 additions and 84 deletions
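The cancellation fix in 0f23b7bff5 hinges on a standard Go idiom: pair every blocking channel send with a ctx.Done() case in a select, so a disconnected client unblocks the producer instead of deadlocking it. Below is a minimal standalone sketch of that idiom; the names (produce, responses) are illustrative and not taken from this diff.

package main

import (
	"context"
	"fmt"
	"time"
)

// produce streams results until the work completes or the client goes away.
// Without the ctx.Done() case, an abandoned receiver would block the send
// forever and the pipeline could never move on to the next request.
func produce(ctx context.Context, responses chan<- int) error {
	defer close(responses)
	for i := 0; i < 100; i++ {
		select {
		case <-ctx.Done():
			return ctx.Err() // client disconnected: stop computing
		case responses <- i:
		}
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	responses := make(chan int)
	go produce(ctx, responses)

	fmt.Println(<-responses) // read one result, then walk away
	cancel()                 // producer exits via ctx.Done() instead of hanging
	time.Sleep(10 * time.Millisecond)
}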

View File

@@ -35,6 +35,7 @@ import (
 var (
 	wv           = &Webview{}
 	uiServerPort int
+	appStore     *store.Store
 )
 
 var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
@@ -208,6 +209,7 @@ func main() {
 	uiServerPort = port
 
 	st := &store.Store{}
+	appStore = st
 
 	// Enable CORS in development mode
 	if devMode {
@@ -360,8 +362,7 @@ func startHiddenTasks() {
 			slog.Info("deferring pending update for fast startup")
 		} else {
 			// Check if auto-update is enabled before automatically upgrading
-			st := &store.Store{}
-			settings, err := st.Settings()
+			settings, err := appStore.Settings()
 			if err != nil {
 				slog.Warn("failed to load settings for upgrade check", "error", err)
 			} else if !settings.AutoUpdateEnabled {

View File

@@ -289,6 +289,7 @@ func (u *Updater) TriggerImmediateCheck() {
 func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
 	u.checkNow = make(chan struct{}, 1)
+	u.checkNow <- struct{}{} // Trigger first check after initial delay
 
 	go func() {
 		// Don't blast an update message immediately after startup
 		time.Sleep(UpdateCheckInitialDelay)
@@ -333,7 +334,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
 				continue
 			}
 
-			// Download successful - show tray notification (regardless of toggle state)
+			// Download successful - show tray notification
 			err = cb(resp.UpdateVersion)
 			if err != nil {
 				slog.Warn("failed to register update available with tray", "error", err)
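The hunk above works because checkNow is a 1-buffered channel: pre-loading a single token makes the first check fire as soon as the initial delay elapses, and later triggers coalesce into at most one pending check. A standalone sketch of the pattern, assuming (the diff does not show it) that TriggerImmediateCheck performs a non-blocking send; the names here are simplified stand-ins for the real Updater.

package main

import (
	"fmt"
	"time"
)

type checker struct {
	checkNow chan struct{}
}

// trigger requests a check; the non-blocking send means a burst of
// triggers collapses into a single pending check.
func (c *checker) trigger() {
	select {
	case c.checkNow <- struct{}{}:
	default: // a check is already queued
	}
}

func main() {
	c := &checker{checkNow: make(chan struct{}, 1)}
	c.checkNow <- struct{}{} // pre-load: first check runs right after the delay

	go func() {
		time.Sleep(50 * time.Millisecond) // stand-in for UpdateCheckInitialDelay
		for range c.checkNow {
			fmt.Println("checking for updates")
		}
	}()

	c.trigger() // coalesces with the pre-loaded token: still one pending check
	time.Sleep(200 * time.Millisecond)
}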

View File

@@ -351,10 +351,13 @@ func TestTriggerImmediateCheck(t *testing.T) {
 	updater.StartBackgroundUpdaterChecker(ctx, cb)
 
-	// Wait for goroutine to start and pass initial delay
-	time.Sleep(10 * time.Millisecond)
+	// Wait for the initial check that fires after the initial delay
+	select {
+	case <-checkDone:
+	case <-time.After(2 * time.Second):
+		t.Fatal("initial check did not happen")
+	}
 
 	// With 1 hour interval, no check should have happened yet
 	initialCount := checkCount.Load()
 
 	// Trigger immediate check
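This test change swaps a sleep for an explicit synchronization point, the reliable way to wait on a goroutine in Go tests: block on a completion channel with a generous timeout rather than guessing how long startup takes. A self-contained sketch, where checkDone is a stand-in for the test's real hook:

package main

import (
	"testing"
	"time"
)

func TestWaitsForFirstCheck(t *testing.T) {
	checkDone := make(chan struct{}, 1)

	go func() { // stand-in for the background checker goroutine
		time.Sleep(5 * time.Millisecond)
		checkDone <- struct{}{}
	}()

	select {
	case <-checkDone:
		// the first check has happened; assertions are now deterministic
	case <-time.After(2 * time.Second):
		t.Fatal("initial check did not happen")
	}
}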

View File

@@ -9,59 +9,99 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
+	"github.com/ollama/ollama/x/mlxrunner/model/base"
 )
 
-// CacheEntry stores a single sequence
-type CacheEntry struct {
-	Tokens []int32
-	Caches []cache.Cache
+type kvCache struct {
+	// For now we only support a single entry, so this is just one sequence
+	tokens []int32
+	caches []cache.Cache
 }
 
-// FindNearestCache finds the longest common prefix between tokens and the cached sequence
-func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
-	if r.cache == nil {
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
-	}
+// cacheSession manages caches for a single pipeline run.
+// Callers should append generated tokens to outputs and
+// defer close to save the cache state.
+type cacheSession struct {
+	cache *kvCache
+
+	inputs  []int32
+	outputs []int32
+
+	caches    []cache.Cache
+	remaining []int32
+}
+
+// begin prepares caches for a new request. It finds the nearest
+// matching cache or creates new caches if none match.
+func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
+	if len(c.caches) == 0 {
+		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
+			c.caches = cacheFactory.NewCaches()
+		} else {
+			c.caches = make([]cache.Cache, m.NumLayers())
+			for i := range c.caches {
+				c.caches[i] = cache.NewKVCache()
+			}
+		}
+	}
+
+	// Find longest common prefix
+	remaining := c.findRemaining(inputs)
+
+	return &cacheSession{
+		cache:     c,
+		inputs:    inputs,
+		caches:    c.caches,
+		remaining: remaining,
+	}
+}
+
+// close saves the token state if the forward pass ran.
+func (s *cacheSession) close() {
+	if offset := s.caches[0].Offset(); offset > 0 {
+		// Ensure that if we have run the forward pass and set the metadata
+		// that we also actually have the data
+		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
+		for _, c := range s.caches {
+			k, v := c.State()
+			arrays = append(arrays, k, v)
+		}
+		mlx.AsyncEval(arrays...)
+
+		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
+	}
+}
 
+// findRemaining finds the longest common prefix between tokens and the cached
+// sequence, trims stale cache entries, and returns the remaining tokens.
+func (c *kvCache) findRemaining(tokens []int32) []int32 {
 	prefix := 0
-	for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
+	for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
 		prefix++
 	}
 
-	switch {
-	case prefix == 0:
-		for _, c := range r.cache.Caches {
-			c.Free()
-		}
-		r.cache = nil
-		slog.Info("Cache miss", "left", len(tokens))
-		return nil, tokens
-	case prefix < len(r.cache.Tokens):
-		trim := len(r.cache.Tokens) - prefix
-		for _, c := range r.cache.Caches {
-			c.Trim(trim)
+	if prefix < len(c.tokens) {
+		trim := len(c.tokens) - prefix
+		for _, kv := range c.caches {
+			kv.Trim(trim)
 		}
-		r.cache.Tokens = r.cache.Tokens[:prefix]
+		c.tokens = c.tokens[:prefix]
 	}
 
-	slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
-	return r.cache.Caches, tokens[prefix:]
+	if prefix == 0 {
+		slog.Info("Cache miss", "left", len(tokens))
+	} else {
+		slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
+	}
+
+	return tokens[prefix:]
 }
 
-func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
-	r.cache = &CacheEntry{
-		Tokens: tokens,
-		Caches: caches,
-	}
-}
+func (c *kvCache) log() {
+	if len(c.caches) == 0 {
+		return
+	}
 
-func (c *CacheEntry) LogCache() {
 	var totalBytes int
-	for _, kv := range c.Caches {
+	for _, kv := range c.caches {
 		k, v := kv.State()
 		totalBytes += k.NumBytes() + v.NumBytes()
 	}
 
-	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
+	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
 }
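The kvCache/cacheSession split above pairs begin with a deferred close so that cache state is saved on every exit path, which is what 4e57d2094e means by "proper deferrals". A reduced sketch of the same shape, using simplified stand-in types (store, session) rather than the real kvCache:

package main

import "fmt"

type store struct{ tokens []int32 }

type session struct {
	s       *store
	inputs  []int32
	outputs []int32
}

func (s *store) begin(inputs []int32) *session {
	// Longest common prefix with the cached sequence; anything past it is stale.
	prefix := 0
	for prefix < len(inputs) && prefix < len(s.tokens) && inputs[prefix] == s.tokens[prefix] {
		prefix++
	}
	s.tokens = s.tokens[:prefix]
	return &session{s: s, inputs: inputs}
}

// close saves the full token sequence, even if the caller returned early.
func (sess *session) close() {
	sess.s.tokens = append(sess.inputs, sess.outputs...)
}

func run(s *store, inputs []int32) {
	sess := s.begin(inputs)
	defer sess.close() // runs on every exit path, including errors
	sess.outputs = append(sess.outputs, 42)
}

func main() {
	s := &store{}
	run(s, []int32{1, 2, 3})
	fmt.Println(s.tokens) // [1 2 3 42]
}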

View File

@@ -10,7 +10,6 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
 
@@ -19,6 +18,23 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return errors.New("model not loaded")
 	}
 
+	var (
+		sample, logprobs         *mlx.Array
+		nextSample, nextLogprobs *mlx.Array
+	)
+
+	defer func() {
+		mlx.Unpin(sample, logprobs)
+		mlx.Unpin(nextSample, nextLogprobs)
+		mlx.Sweep()
+		mlx.ClearCache()
+
+		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
+			mlx.LogArrays()
+			r.cache.log()
+		}
+	}()
+
 	enableCompile := true
 	if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
 		enableCompile = modelCompile.EnableCompile()
@@ -30,22 +46,19 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 	}
 
 	inputs := r.Tokenizer.Encode(request.Prompt, true)
+	session := r.cache.begin(r.Model, inputs)
+	defer session.close()
 
-	caches, tokens := r.FindNearestCache(inputs)
-	if len(caches) == 0 {
-		if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
-			caches = cacheFactory.NewCaches()
-		} else {
-			caches = make([]cache.Cache, r.Model.NumLayers())
-			for i := range caches {
-				caches[i] = cache.NewKVCache()
-			}
-		}
-	}
+	caches := session.caches
+	tokens := session.remaining
 
 	total, processed := len(tokens), 0
 	slog.Info("Prompt processing progress", "processed", processed, "total", total)
 
 	for total-processed > 1 {
+		if err := request.Ctx.Err(); err != nil {
+			return err
+		}
+
 		n := min(2<<10, total-processed-1)
 		r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
 		mlx.Sweep()
@@ -76,15 +89,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		return sample, logprobs
 	}
 
-	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
+	sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
 
 	var b bytes.Buffer
 	now := time.Now()
 	final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
-	outputs := make([]int32, 0, request.Options.MaxTokens)
 	for i := range request.Options.MaxTokens {
-		nextSample, nextLogprobs := step(sample)
+		if err := request.Ctx.Err(); err != nil {
+			return err
+		}
+
+		nextSample, nextLogprobs = step(sample)
+
 		if i == 0 {
 			slog.Info("Prompt processing progress", "processed", total, "total", total)
@@ -94,43 +110,40 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
 		}
 
 		output := int32(sample.Int())
-		outputs = append(outputs, output)
+		session.outputs = append(session.outputs, output)
 		if r.Tokenizer.IsEOS(output) {
-			mlx.Unpin(nextSample, nextLogprobs)
 			final.Token = int(output)
 			final.DoneReason = 0
 			final.CompletionTokens = i
 			break
 		}
 
-		request.Responses <- Response{
+		select {
+		case <-request.Ctx.Done():
+			return request.Ctx.Err()
+		case request.Responses <- Response{
 			Text:  r.Decode(output, &b),
 			Token: int(output),
+		}:
 		}
 
 		mlx.Unpin(sample, logprobs)
+		sample, logprobs = nextSample, nextLogprobs
+		nextSample, nextLogprobs = nil, nil
+
 		if i%256 == 0 {
 			mlx.ClearCache()
 		}
-
-		sample, logprobs = nextSample, nextLogprobs
 	}
 
-	mlx.Unpin(sample, logprobs)
 	final.CompletionTokensDuration = time.Since(now)
-	request.Responses <- final
-
-	r.InsertCache(append(inputs, outputs...), caches)
-	mlx.Sweep()
-
-	if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
-		mlx.LogArrays()
-		if r.cache != nil {
-			r.cache.LogCache()
-		}
-	}
-
-	return nil
+	select {
+	case <-request.Ctx.Done():
+		return request.Ctx.Err()
+	case request.Responses <- final:
+		return nil
+	}
 }
 
 func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
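The pipeline hunks above move all unpinning into one deferred closure: the loop variables are declared up front, ownership hand-offs nil out the source, and the defer releases whatever is still live on any exit path, including the new early returns on request.Ctx. A minimal sketch of that shape, where release is a stand-in for mlx.Unpin:

package main

import "fmt"

// release frees whichever values are still live; nil entries are skipped.
func release(vals ...*int) {
	for _, v := range vals {
		if v != nil {
			fmt.Println("released", *v)
		}
	}
}

func generate(n int) error {
	var cur, next *int
	defer func() { release(cur, next) }() // covers every return path

	for i := 0; i < n; i++ {
		v := i
		next = &v
		if i == 2 {
			return fmt.Errorf("early exit at %d", i) // defer still releases cur and next
		}
		release(cur)
		cur, next = next, nil // hand off; nil-ing prevents a double release in the defer
	}
	return nil
}

func main() { _ = generate(5) }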

View File

@@ -12,7 +12,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
-	"github.com/ollama/ollama/x/mlxrunner/cache"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 	"github.com/ollama/ollama/x/mlxrunner/model"
 	"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -25,8 +24,9 @@ type Request struct {
 	Responses chan Response
 	Pipeline  func(Request) error
+	Ctx       context.Context
 
 	sample.Sampler
-	caches []cache.Cache
 }
 
 type TextCompletionsRequest struct {
@@ -61,7 +61,7 @@ type Runner struct {
 	Model     base.Model
 	Tokenizer *tokenizer.Tokenizer
 	Requests  chan Request
-	cache     *CacheEntry
+	cache     kvCache
 }
 
 func (r *Runner) Load(modelName string) error {
@@ -157,7 +157,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
 			return nil
 		case request := <-r.Requests:
 			if err := request.Pipeline(request); err != nil {
-				break
+				slog.Info("Request terminated", "error", err)
 			}
 			close(request.Responses)

View File

@@ -5,6 +5,7 @@ package mlxrunner
 import (
 	"bytes"
 	"cmp"
+	"context"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -98,19 +99,36 @@ func Execute(args []string) error {
 		request.Options.TopK,
 	)
 
-	runner.Requests <- request
+	var cancel context.CancelFunc
+	request.Ctx, cancel = context.WithCancel(r.Context())
+	defer cancel()
+
+	select {
+	case <-r.Context().Done():
+		return
+	case runner.Requests <- request:
+	}
 
 	w.Header().Set("Content-Type", "application/jsonl")
 	w.WriteHeader(http.StatusOK)
 
 	enc := json.NewEncoder(w)
-	for response := range request.Responses {
-		if err := enc.Encode(response); err != nil {
-			slog.Error("Failed to encode response", "error", err)
-			return
-		}
-		if f, ok := w.(http.Flusher); ok {
-			f.Flush()
+	for {
+		select {
+		case <-r.Context().Done():
+			return
+		case response, ok := <-request.Responses:
+			if !ok {
+				return
+			}
+			if err := enc.Encode(response); err != nil {
+				slog.Error("Failed to encode response", "error", err)
+				return
+			}
+			if f, ok := w.(http.Flusher); ok {
+				f.Flush()
+			}
 		}
 	}
 })
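Put together, the handler now derives a per-request context, hands it to the runner, and selects on both the client's context and the response channel while streaming JSONL. A self-contained sketch of that shape using only the standard library; the endpoint, payload, and in-process worker are made up for illustration:

package main

import (
	"encoding/json"
	"log/slog"
	"net/http"
	"time"
)

type response struct {
	Text string `json:"text"`
}

func main() {
	http.HandleFunc("/generate", func(w http.ResponseWriter, r *http.Request) {
		responses := make(chan response)
		go func() { // stand-in for the runner goroutine
			defer close(responses)
			for i := 0; i < 3; i++ {
				select {
				case <-r.Context().Done():
					return // client gone: stop producing
				case responses <- response{Text: "token"}:
				}
				time.Sleep(100 * time.Millisecond)
			}
		}()

		w.Header().Set("Content-Type", "application/jsonl")
		enc := json.NewEncoder(w)
		for {
			select {
			case <-r.Context().Done():
				return // client disconnected; the worker sees the same context
			case resp, ok := <-responses:
				if !ok {
					return // worker finished and closed the channel
				}
				if err := enc.Encode(resp); err != nil {
					slog.Error("encode failed", "error", err)
					return
				}
				if f, ok := w.(http.Flusher); ok {
					f.Flush() // push each JSON line to the client immediately
				}
			}
		}
	})
	http.ListenAndServe(":8080", nil)
}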