Compare commits


4 Commits

Author SHA1 Message Date
Jesse Gross
6586e61525 mlxrunner: Fix prompt eval timing and count metrics
Only the last token's processing time is included in prompt processing,
giving an artificially high rate. In addition, the token count includes
only the tokens that miss the cache, rather than the total tokens we
have historically reported.
2026-02-26 15:14:47 -08:00
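
A minimal sketch of the corrected accounting (hypothetical names, not the runner's actual code): the timer spans every prompt batch rather than only the last one, and the reported count is the full prompt length, including tokens already served from the cache.

```go
package main

import (
	"fmt"
	"time"
)

// processPrompt illustrates the intended metrics: start the clock before the
// first batch, forward every batch, and report the total prompt length.
func processPrompt(prompt []int, cached, batchSize int, forward func([]int)) (count int, dur time.Duration) {
	start := time.Now()
	for i := cached; i < len(prompt); i += batchSize { // all batches, not only the final one
		end := min(i+batchSize, len(prompt))
		forward(prompt[i:end])
	}
	return len(prompt), time.Since(start) // total tokens, full duration
}

func main() {
	count, dur := processPrompt(make([]int, 100), 40, 16, func([]int) { time.Sleep(time.Millisecond) })
	fmt.Printf("prompt_eval_count=%d prompt_eval_duration=%s\n", count, dur)
}
```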
Jesse Gross
f6229d2464 mlxrunner: Enforce model context limit
Currently, context length is unbounded: the cache will keep
growing forever, independent of the model's trained context
length. This caps it and enforces semantics similar to most
cloud services:
 - Long prompts result in an error rather than truncation.
 - Generation that exceeds the context is stopped.
2026-02-26 13:55:24 -08:00
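
Condensed from the pipeline hunk further down, the semantics look roughly like this (a sketch with hypothetical names, not the runner's actual signature):

```go
package main

import "fmt"

// enforceContext rejects prompts that exceed the model's context window and
// caps generation so prompt + completion stays within it.
func enforceContext(promptLen, maxTokens, contextLen int) (int, error) {
	if promptLen >= contextLen {
		return 0, fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", promptLen, contextLen)
	}
	maxGenerate := contextLen - promptLen
	if maxTokens <= 0 || maxTokens > maxGenerate {
		maxTokens = maxGenerate
	}
	return maxTokens, nil
}

func main() {
	if n, err := enforceContext(1000, 0, 4096); err == nil {
		fmt.Println("max tokens to generate:", n) // 3096
	}
}
```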
Jesse Gross
fda19c1282 mlxrunner: Propagate pipeline errors to client via api.StatusError
Errors that occur during pipeline processing are currently only
logged but not sent back to the client. Rather than using HTTP
status codes as we have historically done, this serializes errors
as messages to allow sending them at any time during the stream.
2026-02-26 13:30:37 -08:00
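
The idea is that each streamed line can carry either content or a serialized error, so the client checks for an error field before treating the line as output. A rough sketch of the client side (assumed field names, loosely mirroring the hunks below):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// StatusError stands in for api.StatusError in this sketch.
type StatusError struct {
	StatusCode   int    `json:"status_code"`
	ErrorMessage string `json:"error"`
}

func (e StatusError) Error() string { return e.ErrorMessage }

// chunk is one line of the streamed response: content, a done marker, or an error.
type chunk struct {
	Content string       `json:"content,omitempty"`
	Done    bool         `json:"done"`
	Error   *StatusError `json:"error,omitempty"`
}

func handleLine(line []byte) (string, error) {
	var c chunk
	if err := json.Unmarshal(line, &c); err != nil {
		return "", err
	}
	if c.Error != nil {
		return "", *c.Error // surfaced mid-stream, not via the HTTP status code
	}
	return c.Content, nil
}

func main() {
	out, err := handleLine([]byte(`{"error":{"status_code":400,"error":"context exceeded"}}`))
	fmt.Println(out, err)
}
```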
Jesse Gross
93ca624a7c mlxrunner: Report actual memory usage from runner
The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
2026-02-26 10:41:06 -08:00
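
In practice the runner exposes its live memory usage through its status endpoint, and the client reports that instead of the load-time weight estimate. A rough sketch of the shape of that exchange (illustrative names; the real fields are in the hunks below):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// status mirrors the runner's status payload: live memory at query time
// rather than a static estimate of the weights.
type status struct {
	ContextLength int  `json:"ContextLength"`
	Memory        uint `json:"Memory"` // active + cache memory
}

// memorySize reports (total, vram); the MLX client returns the same value
// for both, as in the MemorySize implementation further down.
func memorySize(body []byte) (total, vram uint64, err error) {
	var s status
	if err := json.Unmarshal(body, &s); err != nil {
		return 0, 0, err
	}
	return uint64(s.Memory), uint64(s.Memory), nil
}

func main() {
	total, vram, _ := memorySize([]byte(`{"ContextLength":32768,"Memory":9663676416}`))
	fmt.Println(total, vram)
}
```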
22 changed files with 215 additions and 288 deletions

View File

@@ -296,15 +296,8 @@ func main() {
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
// before that would dereference a nil tray callback.
// TODO: refactor so the update check runs after platform init on all platforms.
if runtime.GOOS == "windows" {
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
} else {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()

View File

@@ -154,10 +154,6 @@ func handleURLSchemeRequest(urlScheme string) {
}
func UpdateAvailable(ver string) error {
if app.t == nil {
slog.Debug("tray not yet initialized, skipping update notification")
return nil
}
return app.t.UpdateAvailable(ver)
}
@@ -169,14 +165,6 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
log.Fatalf("Failed to start: %s", err)
}
// Check for pending updates now that the tray is initialized.
// The platform-independent check in app.go fires before osRun,
// when app.t is still nil, so we must re-check here.
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

View File

@@ -74,8 +74,7 @@ type LlamaServer interface {
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
VRAMSize() uint64 // Total VRAM across all GPUs
TotalSize() uint64
MemorySize() (total, vram uint64)
VRAMByGPU(id ml.DeviceID) uint64
Pid() int
GetPort() int
@@ -685,8 +684,9 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache
totalSize, _ := s.MemorySize()
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1848,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
func (s *llmServer) VRAMSize() uint64 {
func (s *llmServer) MemorySize() (total, vram uint64) {
if s.mem == nil {
return 0
return 0, 0
}
var mem uint64
for _, g := range s.mem.GPUs {
mem += g.Size()
vram += g.Size()
}
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
// Some elements are always on CPU. However, if we have allocated all layers
// on the GPU then include the CPU components as well, to represent complete offloading.
noCPULayers := true
@@ -1869,25 +1869,11 @@ func (s *llmServer) VRAMSize() uint64 {
}
}
if noCPULayers {
mem += s.mem.InputWeights
mem += s.mem.CPU.Graph
vram += s.mem.InputWeights
vram += s.mem.CPU.Graph
}
return mem
}
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
mem := s.mem.InputWeights
mem += s.mem.CPU.Size()
for _, g := range s.mem.GPUs {
mem += g.Size()
}
return mem
return total, vram
}
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {

View File

@@ -204,24 +204,6 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.maybeThinkingOpenAtBOL = false
}
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventThinkingContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
@@ -233,7 +215,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen

View File

@@ -146,68 +146,6 @@ func TestQwen3ParserToolCall(t *testing.T) {
}
}
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "Let me think" {
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
}
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "Let me think" || len(calls) != 0 {
t.Fatalf(
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
"",
"Let me think",
0,
content,
thinking,
len(calls),
)
}
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
}
func TestQwen35ParserRespectsNoThink(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {

View File

@@ -180,22 +180,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
return events, false
}
case CollectingThinkingContent:
acc := p.buffer.String()
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
toolOpenIdx := strings.Index(acc, toolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventThinkingContent{content: before})
}
p.state = CollectingToolContent
return events, true
}
if strings.Contains(acc, thinkingCloseTag) {
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
@@ -206,13 +191,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
p.state = CollectingContent
}
return events, true
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
@@ -220,11 +205,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {

View File

@@ -98,12 +98,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
desc: "nested thinking and tool call (outside thinking, inside tool call)",
steps: []step{
{
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm nested tool call"},
qwenEventContent{content: "</think>"},
},
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
},
},
},
@@ -113,7 +109,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
qwenEventContent{content: "</tool_call>"},
},
},
},
@@ -124,8 +121,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
qwenEventContent{content: "</tool_call>"},
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
qwenEventContent{content: "</think>"},
},

View File

@@ -71,6 +71,10 @@ type Model struct {
Template *template.Template
}
func (m *Model) IsMLX() bool {
return m.Config.ModelFormat == "safetensors"
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}

View File

@@ -30,42 +30,44 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
lastMsgIdx := len(msgs) - 1
currMsgIdx := 0
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
if truncate {
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
}
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
}
if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
}
if ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
}
}
}

View File

@@ -484,7 +484,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1951,6 +1952,9 @@ func (s *Server) PsHandler(c *gin.Context) {
}
if v.llama != nil {
mr.ContextLength = v.llama.ContextLength()
total, vram := v.llama.MemorySize()
mr.Size = int64(total)
mr.SizeVRAM = int64(vram)
}
// The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just
@@ -2213,6 +2217,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error", "error", err)

View File

@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if pending.model.IsMLX() {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
@@ -536,6 +536,7 @@ iGPUScan:
}
}
totalSize, vramSize := llama.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -545,8 +546,8 @@ iGPUScan:
sessionDuration: sessionDuration,
gpus: gpuIDs,
discreteGPUs: discreteGPUs,
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
totalSize: totalSize,
vramSize: vramSize,
loading: true,
pid: llama.Pid(),
}
@@ -619,6 +620,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
sessionDuration = req.sessionDuration.Duration
}
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -628,8 +630,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
loading: false,
isImagegen: isImagegen,
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
totalSize: totalSize,
vramSize: vramSize,
}
s.loadedMu.Lock()
@@ -762,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}

View File

@@ -861,8 +861,7 @@ func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize }
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) GetPort() int { return -1 }

View File

@@ -374,14 +374,9 @@ func (s *Server) Close() error {
return nil
}
// VRAMSize returns the estimated VRAM usage.
func (s *Server) VRAMSize() uint64 {
return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
// MemorySize returns the total and VRAM memory usage.
func (s *Server) MemorySize() (total, vram uint64) {
return s.vramSize, s.vramSize
}
// VRAMByGPU returns VRAM usage for a specific GPU.

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
@@ -21,23 +20,24 @@ import (
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
port int
modelName string
vramSize uint64
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
contextLength int
memory uint
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
}
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
@@ -98,18 +98,9 @@ func NewClient(modelName string) (*Client, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{
port: port,
modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd,
@@ -201,6 +192,19 @@ type completionOpts struct {
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Error *api.StatusError
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
@@ -260,28 +264,24 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
var raw CompletionResponse
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount,
EvalDuration: time.Duration(raw.EvalDuration),
EvalDuration: raw.EvalDuration,
}
fn(cresp)
@@ -294,7 +294,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
func (c *Client) ContextLength() int {
return math.MaxInt
return c.contextLength
}
// Detokenize implements llm.LlamaServer.
@@ -347,9 +347,16 @@ func (c *Client) Pid() int {
return -1
}
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return err
@@ -362,6 +369,15 @@ func (c *Client) Ping(ctx context.Context) error {
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
var status statusResponse
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength = status.ContextLength
c.memory = status.Memory
return nil
}
@@ -388,19 +404,24 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
return tokens, nil
}
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
return c.vramSize
func (c *Client) currentMemory() uint64 {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := c.Ping(ctx); err != nil {
slog.Warn("failed to get current memory", "error", err)
}
return uint64(c.memory)
}
// MemorySize implements llm.LlamaServer.
func (c *Client) MemorySize() (total, vram uint64) {
mem := c.currentMemory()
return mem, mem
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.vramSize
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
return c.vramSize
return c.currentMemory()
}
// WaitUntilRunning implements llm.LlamaServer.

View File

@@ -20,6 +20,7 @@ type Model interface {
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert

View File

@@ -6,9 +6,12 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -46,12 +49,28 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
now := time.Now()
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
@@ -93,8 +112,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil {
return err
@@ -105,7 +123,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
@@ -113,18 +131,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
final.EvalCount = i
break
}
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
}:
}
@@ -137,7 +153,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
final.CompletionTokensDuration = time.Since(now)
final.EvalDuration = time.Since(now)
select {
case <-request.Ctx.Done():
return request.Ctx.Err()

View File

@@ -4,14 +4,15 @@ package mlxrunner
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -21,7 +22,7 @@ import (
type Request struct {
TextCompletionsRequest
Responses chan Response
Responses chan CompletionResponse
Pipeline func(Request) error
Ctx context.Context
@@ -43,25 +44,12 @@ type TextCompletionsRequest struct {
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache kvCache
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache kvCache
contextLength int
}
func (r *Runner) Load(modelName string) error {
@@ -90,6 +78,7 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
return nil
}
@@ -158,6 +147,17 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
}
close(request.Responses)

View File

@@ -50,9 +50,11 @@ func Execute(args []string) error {
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{
"status": 0,
"progress": 100,
if err := json.NewEncoder(w).Encode(statusResponse{
Status: 0,
Progress: 100,
ContextLength: runner.contextLength,
Memory: uint(mlx.ActiveMemory() + mlx.CacheMemory()),
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -78,7 +80,7 @@ func Execute(args []string) error {
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)}
request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
@@ -87,9 +89,6 @@ func Execute(args []string) error {
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(

View File

@@ -430,6 +430,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

View File

@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

View File

@@ -262,6 +262,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

View File

@@ -279,6 +279,10 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}