Compare commits

...

2 Commits

Author          SHA1        Message                              Date
Patrick Devine  8c261d3c26  fix scheduling hang                  2026-02-16 18:36:25 -08:00
Patrick Devine  32e605af8f  bugfix: better mlx model scheduling  2026-02-16 17:23:42 -08:00

Commit 32e605af8f message body:
This change properly loads in different mlx (safetensors) based models so that subsequent
model loads don't improperly do inference on the already loaded model.
2 changed files with 265 additions and 76 deletions
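Note on the second commit (32e605af8f): before this change the scheduler keyed its map of loaded runners by ModelPath, which is empty for mlx/safetensors models, so two different safetensors models collapsed onto the same empty key and a request for the second model was served by the runner already loaded for the first. The schedulerModelKey helper added in the diff below falls back to the manifest digest (then name, then short name) when ModelPath is empty. The following is a minimal, self-contained sketch of the collision and the digest-keyed fix; the Model struct and the runner strings are simplified stand-ins, not the real scheduler types.

package main

import "fmt"

// Simplified stand-in for the scheduler's model metadata (illustrative only).
type Model struct {
	ModelPath string // set for GGUF models, empty for mlx/safetensors models
	Digest    string
	Name      string
	ShortName string
}

// Mirrors the key precedence of schedulerModelKey in the diff below:
// ModelPath, then manifest digest, then name, then short name.
func schedulerModelKey(m *Model) string {
	switch {
	case m == nil:
		return ""
	case m.ModelPath != "":
		return m.ModelPath
	case m.Digest != "":
		return "digest:" + m.Digest
	case m.Name != "":
		return "name:" + m.Name
	case m.ShortName != "":
		return "short:" + m.ShortName
	default:
		return ""
	}
}

func main() {
	loaded := map[string]string{} // stand-in for s.loaded (key -> runner)

	a := &Model{Name: "safetensors-a", Digest: "sha-a"}
	b := &Model{Name: "safetensors-b", Digest: "sha-b"}

	// Old behavior: keying by ModelPath alone puts both models under "",
	// so a lookup for b finds a's runner and inference runs on the wrong model.
	loaded[a.ModelPath] = "runner-a"
	fmt.Printf("lookup b by ModelPath: %q\n", loaded[b.ModelPath]) // "runner-a" (collision)

	// New behavior: digest-based keys keep the two runners distinct.
	delete(loaded, "")
	loaded[schedulerModelKey(a)] = "runner-a"
	fmt.Printf("lookup b by key: %q\n", loaded[schedulerModelKey(b)]) // "" (miss, b loads its own runner)
}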

View File

@@ -22,6 +22,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/mlxrunner"
)
@@ -83,6 +84,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
// schedulerModelKey returns the scheduler map key for a model.
// GGUF-backed models use ModelPath; safetensors/image models without a
// ModelPath use manifest digest so distinct models don't collide.
func schedulerModelKey(m *Model) string {
if m == nil {
return ""
}
if m.ModelPath != "" {
return m.ModelPath
}
if m.Digest != "" {
return "digest:" + m.Digest
}
if m.Name != "" {
return "name:" + m.Name
}
if m.ShortName != "" {
return "short:" + m.ShortName
}
return ""
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
@@ -104,8 +127,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
useImagegen: useImagegen,
}
key := schedulerModelKey(req.model)
s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath]
runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +175,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -166,7 +191,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -198,54 +223,54 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
// Check for image generation models and experimental safetensors
// LLM models - both use MLX runner.
isMLXModel := slices.Contains(pending.model.Config.Capabilities, "image") ||
(pending.model.Config.ModelFormat == "safetensors" &&
slices.Contains(pending.model.Config.Capabilities, "completion"))
if isMLXModel {
// Account for currently loaded runners before fit estimate.
s.updateFreeSpace(gpus)
if loadedCount > 0 && s.shouldEvictForMLXLoad(pending.model.ShortName, gpus) {
slog.Debug("mlx model does not fit with loaded runners, evicting one first", "model", pending.model.ShortName)
runnerToExpire = s.findRunnerToUnload()
} else {
if s.loadMLX(pending) {
break
}
continue
}
} else {
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
if err != nil {
pending.errCh <- err
break
}
// Update free memory from currently loaded models
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
s.updateFreeSpace(gpus)
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, ggml, systemInfo, gpus, false)
break
}
// More than one loaded model, so we have to see if the
// new one fits
logutil.Trace("loading additional model", "model", pending.model.ModelPath)
needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
}
runnerToExpire = s.findRunnerToUnload()
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
if err != nil {
pending.errCh <- err
break
}
// Update free memory from currently loaded models
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
s.updateFreeSpace(gpus)
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, ggml, systemInfo, gpus, false)
break
}
// More than one loaded model, so we have to see if the
// new one fits
logutil.Trace("loading additional model", "model", pending.model.ModelPath)
needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
}
runnerToExpire = s.findRunnerToUnload()
}
if runnerToExpire == nil {
@@ -292,22 +317,24 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "runner", runner)
if runner.sessionDuration <= 0 || runner.expirePending {
slog.Debug("runner has gone idle with pending expiration, expiring to unload", "runner", runner)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
runner.expirePending = false
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "runner", runner, "duration", runner.sessionDuration)
@@ -315,16 +342,24 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("timer expired, expiring to unload", "runner", runner)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
// Timer was canceled after callback dispatch.
if runner.expireTimer == nil {
return
}
runner.expireTimer.Stop()
runner.expireTimer = nil
if runner.refCount > 0 {
runner.expirePending = true
return
}
s.expiredCh <- runner
})
runner.expirePending = false
runner.expiresAt = time.Now().Add(runner.sessionDuration)
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "runner", runner, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
runner.expirePending = false
runner.expiresAt = time.Now().Add(runner.sessionDuration)
}
}
@@ -334,20 +369,16 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("runner expired event received", "runner", runner)
runner.refMu.Lock()
if runner.refCount > 0 {
slog.Debug("expired event with positive ref count, retrying", "runner", runner, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
// Mark expiration to happen as soon as this runner goes idle.
slog.Debug("expired event with positive ref count, deferring until idle", "runner", runner, "refCount", runner.refCount)
runner.expirePending = true
runner.refMu.Unlock()
continue
}
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath]
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -376,7 +407,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelPath)
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -514,6 +545,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -528,7 +560,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -536,7 +568,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -566,6 +598,44 @@ iGPUScan:
return false
}
func mlxModelFitsInMemory(required uint64, gpus []ml.DeviceInfo) bool {
if required == 0 || len(gpus) == 0 {
return true
}
var available uint64
for _, gpu := range gpus {
reserved := envconfig.GpuOverhead() + gpu.MinimumMemory()
if gpu.FreeMemory > reserved {
available += gpu.FreeMemory - reserved
}
}
// If we can't compute usable free memory, fall back to attempting a load.
if available == 0 {
return true
}
return required <= available
}
func (s *Scheduler) shouldEvictForMLXLoad(modelName string, gpus []ml.DeviceInfo) bool {
m, err := imagegenmanifest.LoadManifest(modelName)
if err != nil {
return false
}
required := uint64(m.TotalTensorSize())
if fits := mlxModelFitsInMemory(required, gpus); !fits {
slog.Debug("mlx memory fit check failed",
"model", modelName,
"required", format.HumanBytes2(required))
return true
}
return false
}
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
@@ -596,6 +666,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
@@ -606,13 +677,25 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Timer was canceled after callback dispatch.
if runner.expireTimer == nil {
return
}
runner.expireTimer.Stop()
runner.expireTimer = nil
if runner.refCount > 0 {
runner.expirePending = true
return
}
s.expiredCh <- runner
})
}
@@ -637,7 +720,9 @@ func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
r.refMu.Lock()
if r.llama != nil {
for _, gpu := range allGpus {
predMap[gpu.DeviceID] += r.llama.VRAMByGPU(gpu.DeviceID)
if vram, ok := safeVRAMByGPU(r.llama, gpu.DeviceID); ok {
predMap[gpu.DeviceID] += vram
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
@@ -664,6 +749,15 @@ func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
}
}
func safeVRAMByGPU(server llm.LlamaServer, id ml.DeviceID) (_ uint64, ok bool) {
defer func() {
if recover() != nil {
ok = false
}
}()
return server.VRAMByGPU(id), true
}
// TODO consolidate sched_types.go
type runnerRef struct {
refMu sync.Mutex
@@ -680,10 +774,12 @@ type runnerRef struct {
sessionDuration time.Duration
expireTimer *time.Timer
expirePending bool
expiresAt time.Time
model *Model
modelPath string
modelKey string
numParallel int
*api.Options
}
@@ -703,7 +799,7 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
@@ -814,6 +910,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +928,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath),
slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +973,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
// Secondary sort by model path lex order
return a[i].modelPath < a[j].modelPath
// Secondary sort by model key/path lex order
n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -934,8 +1042,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath]
runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()
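Aside on the first commit (8c261d3c26, "fix scheduling hang"): as the hunks above show, an expired runner that still had a positive refCount used to be re-queued onto expiredCh after a 10 ms sleep; it is now marked expirePending and sent for unload as soon as its last reference is released, which is presumably the hang the commit title refers to. Below is a rough sketch of that defer-until-idle idea, using simplified types rather than the scheduler's actual runnerRef and channel wiring.

package main

import (
	"fmt"
	"sync"
)

// Illustrative, stripped-down version of the runner bookkeeping in the diff above.
type runner struct {
	mu            sync.Mutex
	refCount      int
	expirePending bool
}

// expire is called when an expiration fires while the runner may still be busy.
func expire(r *runner, expiredCh chan<- *runner) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.refCount > 0 {
		// Busy: remember that an expiration is owed instead of sleeping
		// and re-queueing the event in a retry loop.
		r.expirePending = true
		return
	}
	expiredCh <- r
}

// release is called when a request finishes and drops its reference.
func release(r *runner, expiredCh chan<- *runner) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.refCount--
	if r.refCount <= 0 && r.expirePending {
		r.expirePending = false
		expiredCh <- r // unload now that the runner is idle
	}
}

func main() {
	expiredCh := make(chan *runner, 1)
	r := &runner{refCount: 1}

	expire(r, expiredCh)  // runner is busy, expiration is deferred
	release(r, expiredCh) // last reference released, expiration fires

	fmt.Println("queued for unload:", len(expiredCh)) // 1
}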

View File

@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
require.Empty(t, successCh)
require.Empty(t, errCh)
require.Len(t, s.pendingReqCh, 1)
}
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
cancelReq()
select {
case runner := <-successCh:
require.Equal(t, loadedRunner, runner)
default:
t.Fatal("expected existing runner to be reused")
}
require.Empty(t, errCh)
require.Empty(t, s.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
@@ -530,12 +595,11 @@ func TestSchedPrematureExpired(t *testing.T) {
time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Empty(t, s.finishedReqCh)
s.loadedMu.Lock()
require.Empty(t, s.loaded)
s.loadedMu.Unlock()
require.Eventually(t, func() bool {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
return len(s.loaded) == 0
}, 300*time.Millisecond, 10*time.Millisecond)
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
@@ -613,6 +677,22 @@ func TestSchedUpdateFreeSpace(t *testing.T) {
require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
}
func TestMLXModelFitsInMemory(t *testing.T) {
t.Setenv("OLLAMA_GPU_OVERHEAD", "0")
gpus := []ml.DeviceInfo{
{
DeviceID: ml.DeviceID{Library: "Metal"},
FreeMemory: 10 * format.GigaByte,
},
}
require.True(t, mlxModelFitsInMemory(9*format.GigaByte, gpus))
require.False(t, mlxModelFitsInMemory(10*format.GigaByte, gpus))
require.True(t, mlxModelFitsInMemory(0, gpus))
require.True(t, mlxModelFitsInMemory(10*format.GigaByte, nil))
}
func TestSchedFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()