Compare commits


1 Commit

Author: Jesse Gross
SHA1: 4d5ff25724
Date: 2026-02-25 15:06:37 -08:00

mlxrunner: Report actual memory usage from runner

The MLX runner previously reported a static VRAM estimate that was
computed at load time and consisted only of the weights. This is
strictly less than the actual memory usage, as it does not include
the KV cache or compute graph.
35 changed files with 165 additions and 895 deletions
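
For context on the mechanism the commit message describes, here is a minimal runnable sketch of a runner status endpoint that reports memory sampled at request time rather than a load-time weights-only estimate. The statusResponse shape and the ActiveMemory/CacheMemory pairing are taken from the hunks below; the stub accessors are stand-ins for the real MLX bindings.

```go
package main

import (
	"encoding/json"
	"net/http"
)

// statusResponse mirrors the status payload exchanged in the hunks below.
type statusResponse struct {
	Status   int
	Progress int
	Memory   uint64
}

// activeMemory and cacheMemory are stand-ins for the mlx.ActiveMemory and
// mlx.CacheMemory bindings used in the diff; real values come from the MLX
// allocator at request time.
func activeMemory() uint64 { return 0 }
func cacheMemory() uint64  { return 0 }

func main() {
	mux := http.NewServeMux()
	// Sample memory when status is requested, not once at load time, so the
	// report includes the KV cache and compute graph, not just the weights.
	mux.HandleFunc("/v1/status", func(w http.ResponseWriter, r *http.Request) {
		_ = json.NewEncoder(w).Encode(statusResponse{
			Status:   0,
			Progress: 100,
			Memory:   activeMemory() + cacheMemory(),
		})
	})
	_ = http.ListenAndServe("127.0.0.1:8080", mux)
}
```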

View File

@@ -15,7 +15,6 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model"
)
@@ -570,7 +569,6 @@ type DebugInfo struct {
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -936,10 +934,6 @@ func (m *Metrics) Summary() {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
}
if m.PeakMemory > 0 {
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
}
if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
}
@@ -963,14 +957,6 @@ func (m *Metrics) Summary() {
}
}
func formatPeakMemory(b uint64) string {
if b >= format.GibiByte {
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
}
return format.HumanBytes2(b)
}
func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
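
A quick worked example of the GiB branch in formatPeakMemory above, assuming format.GibiByte is the usual 1<<30:

```go
package main

import "fmt"

const gibiByte = 1 << 30 // assumption: format.GibiByte == 1<<30

func main() {
	b := uint64(3_500_000_000)
	if b >= gibiByte {
		// Values of a GiB or more print with three decimals.
		fmt.Printf("peak memory: %.3f GiB\n", float64(b)/float64(gibiByte)) // 3.260 GiB
	}
}
```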

View File

@@ -35,7 +35,6 @@ import (
var (
wv = &Webview{}
uiServerPort int
appStore *store.Store
)
var debug = strings.EqualFold(os.Getenv("OLLAMA_DEBUG"), "true") || os.Getenv("OLLAMA_DEBUG") == "1"
@@ -209,7 +208,6 @@ func main() {
uiServerPort = port
st := &store.Store{}
appStore = st
// Enable CORS in development mode
if devMode {
@@ -296,15 +294,8 @@ func main() {
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
// On Windows, the tray is initialized in osRun(). Calling UpdateAvailable
// before that would dereference a nil tray callback.
// TODO: refactor so the update check runs after platform init on all platforms.
if runtime.GOOS == "windows" {
slog.Debug("update pending on startup, deferring tray notification until tray initialization")
} else {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
@@ -369,7 +360,8 @@ func startHiddenTasks() {
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
settings, err := appStore.Settings()
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {

View File

@@ -154,10 +154,6 @@ func handleURLSchemeRequest(urlScheme string) {
}
func UpdateAvailable(ver string) error {
if app.t == nil {
slog.Debug("tray not yet initialized, skipping update notification")
return nil
}
return app.t.UpdateAvailable(ver)
}
@@ -169,14 +165,6 @@ func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
log.Fatalf("Failed to start: %s", err)
}
// Check for pending updates now that the tray is initialized.
// The platform-independent check in app.go fires before osRun,
// when app.t is still nil, so we must re-check here.
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

View File

@@ -289,7 +289,6 @@ func (u *Updater) TriggerImmediateCheck() {
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
u.checkNow <- struct{}{} // Trigger first check after initial delay
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
@@ -334,7 +333,7 @@ func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(str
continue
}
// Download successful - show tray notification
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)

View File

@@ -351,13 +351,10 @@ func TestTriggerImmediateCheck(t *testing.T) {
updater.StartBackgroundUpdaterChecker(ctx, cb)
// Wait for the initial check that fires after the initial delay
select {
case <-checkDone:
case <-time.After(2 * time.Second):
t.Fatal("initial check did not happen")
}
// Wait for goroutine to start and pass initial delay
time.Sleep(10 * time.Millisecond)
// With 1 hour interval, no check should have happened yet
initialCount := checkCount.Load()
// Trigger immediate check

View File

@@ -1518,7 +1518,6 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
// Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"`

View File

@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"`
@@ -135,18 +135,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
}
if gdn.SSMDT == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
}
if gdn.SSMA == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT)
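
A scalar walk-through of the gate formula in the comment above, softplus(alpha + dt_bias) * -A, noting that the ssm_a tensor is documented as already storing -A_log.exp(); the numbers are illustrative:

```go
package main

import (
	"fmt"
	"math"
)

// softplus(x) = log(1 + exp(x))
func softplus(x float64) float64 { return math.Log1p(math.Exp(x)) }

func main() {
	alpha, dtBias := 0.5, -0.25 // illustrative values
	storedA := -math.Exp(-2.0)  // ssm_a holds -exp(A_log); here A_log = -2
	// gate = softplus(alpha + dt_bias) * -A, where the stored tensor value
	// already carries the -A factor.
	gate := softplus(alpha+dtBias) * storedA
	fmt.Printf("gate = %.4f\n", gate)
}
```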

View File

@@ -437,46 +437,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Validate() error {
if m.Options == nil {
return fmt.Errorf("qwen3next: missing model options")
}
if len(m.Layers) != len(m.Options.isRecurrent) {
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
}
for i, layer := range m.Layers {
if !m.Options.isRecurrent[i] {
continue
}
gdn, ok := layer.Operator.(*GatedDeltaNet)
if !ok || gdn == nil {
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
}
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
}
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
}
if gdn.SSMDT == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
}
if gdn.SSMA == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
}
}
return nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
if len(m.mropeSections) > 0 {
@@ -490,64 +450,6 @@ var (
_ model.MultimodalProcessor = (*Model)(nil)
)
func defaultVHeadReordered(arch string) bool {
return arch == "qwen35" || arch == "qwen35moe"
}
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
isRecurrent := make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
if i >= len(headCountKV) {
continue
}
if headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else {
hasFull = true
}
}
if hasZero && hasFull {
return isRecurrent, nil
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Compatibility path: older imports store a scalar KV head count and omit
// per-layer recurrent flags. Derive the hybrid layout from the interval.
interval := int(fullAttentionInterval)
if interval == 0 {
interval = min(4, numLayers)
}
if interval <= 0 {
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
}
if interval > numLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
}
hasZero = false
hasFull = false
for i := range numLayers {
isRecurrent[i] = (i+1)%interval != 0
if isRecurrent[i] {
hasZero = true
} else {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
}
return isRecurrent, nil
}
func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers)
@@ -558,14 +460,26 @@ func New(c fs.Config) (model.Model, error) {
HeadCountKV() []uint64
}
var isRecurrent []bool
var headCountKV []uint64
if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV()
}
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
if err != nil {
return nil, err
isRecurrent = make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
}
// Determine if MoE
@@ -629,7 +543,7 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
vHeadReordered: c.Bool("ssm.v_head_reordered", false),
isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections {
@@ -641,7 +555,7 @@ func New(c fs.Config) (model.Model, error) {
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
}
if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
}
// Calculate cache dimensions
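
For reference, a runnable condensation of the interval layout computed by the removed compatibility path: every interval-th layer (1-indexed) is full attention, the rest recurrent.

```go
package main

import "fmt"

// layoutFromInterval reproduces the removed fallback: derive the hybrid
// recurrent/full-attention layout from full_attention_interval when only a
// scalar KV head count is stored.
func layoutFromInterval(numLayers, interval int) []bool {
	isRecurrent := make([]bool, numLayers)
	for i := range isRecurrent {
		isRecurrent[i] = (i+1)%interval != 0
	}
	return isRecurrent
}

func main() {
	// Matches the removed test: 8 layers with the default interval of 4
	// yield [true true true false true true true false].
	fmt.Println(layoutFromInterval(8, 4))
}
```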

View File

@@ -1,65 +0,0 @@
package qwen3next
import (
"slices"
"strings"
"testing"
)
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, false, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, true, false, true, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, false, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
if err == nil {
t.Fatal("inferRecurrentLayers() expected error, got nil")
}
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestDefaultVHeadReordered(t *testing.T) {
if !defaultVHeadReordered("qwen35") {
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
}
if !defaultVHeadReordered("qwen35moe") {
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
}
if defaultVHeadReordered("qwen3next") {
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
}
}

View File

@@ -1,45 +0,0 @@
package qwen3next
import (
"strings"
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
m := &Model{
Layers: []Layer{{
Operator: &GatedDeltaNet{
SSMQKV: &nn.Linear{},
SSMQKVGate: &nn.Linear{},
SSMBeta: &nn.Linear{},
SSMAlpha: &nn.Linear{},
},
}},
Options: &Options{
isRecurrent: []bool{true},
},
}
err := m.Validate()
if err == nil {
t.Fatal("Validate() expected error, got nil")
}
if !strings.Contains(err.Error(), "missing ssm_dt") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
m := &Model{
Layers: []Layer{{Operator: &FullAttention{}}},
Options: &Options{
isRecurrent: []bool{false},
},
}
if err := m.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}

View File

@@ -32,10 +32,9 @@ const (
)
type GLM46Parser struct {
state glm46ParserState
buffer strings.Builder
tools []api.Tool
callIndex int
state glm46ParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
@@ -49,7 +48,6 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
return tools
}
@@ -91,8 +89,6 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
thinkingSb.WriteString(event.content)

View File

@@ -11,7 +11,6 @@ type GLM47Parser struct {
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
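
The branch body is truncated in this hunk; a small sketch of the state selection the comment implies, with illustrative state names rather than the parser's real identifiers:

```go
package main

import "fmt"

type parserState int

const (
	stateCollectingContent parserState = iota
	stateCollectingThinking
)

// initialState sketches the branch above: think == nil is treated as
// "thinking on", matching the comment.
func initialState(think *bool) parserState {
	// When thinking is enabled, the prompt already ends with <think>, so the
	// model's output starts directly with thinking content (no opening tag).
	if think == nil || *think {
		return stateCollectingThinking
	}
	return stateCollectingContent
}

func main() {
	off := false
	fmt.Println(initialState(nil), initialState(&off)) // 1 0
}
```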

View File

@@ -97,91 +97,3 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
t.Fatalf("expected %#v, got %#v", expected, toolCall)
}
}
func TestGLM47ParserToolCallIndexing(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
input := `plan</think>
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -38,7 +38,6 @@ type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
callIndex int
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
@@ -55,7 +54,6 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
p.callIndex = 0
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
@@ -108,8 +106,6 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)
@@ -208,24 +204,6 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.maybeThinkingOpenAtBOL = false
}
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventThinkingContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
@@ -237,7 +215,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
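
A condensed example of the rule in the removed branch above: if a tool-call tag appears before </think> (or </think> is absent), thinking ends at the tool tag and parsing continues in tool-call mode. Tag constants are inlined here for the example.

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	acc := `Let me think<tool_call>{"name":"get_weather"}</tool_call>`
	thinkingCloseIdx := strings.Index(acc, "</think>")
	toolOpenIdx := strings.Index(acc, "<tool_call>")
	// Same precedence test as the removed code: tool tag first wins.
	if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
		fmt.Printf("thinking=%q, tool content starts after byte %d\n",
			acc[:toolOpenIdx], toolOpenIdx+len("<tool_call>"))
	}
}
```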

View File

@@ -146,68 +146,6 @@ func TestQwen3ParserToolCall(t *testing.T) {
}
}
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "Let me think" {
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
}
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "Let me think" || len(calls) != 0 {
t.Fatalf(
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
"",
"Let me think",
0,
content,
thinking,
len(calls),
)
}
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
}
func TestQwen35ParserRespectsNoThink(t *testing.T) {
parser := ParserForName("qwen3.5")
if parser == nil {
@@ -230,89 +168,3 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCallIndexing(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
var all []api.ToolCall
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -29,10 +29,9 @@ const (
)
type Qwen3CoderParser struct {
state qwenParserState
acc strings.Builder
tools []api.Tool
callIndex int
state qwenParserState
acc strings.Builder
tools []api.Tool
}
func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -45,7 +44,6 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.callIndex = 0
return tools // Qwen doesn't modify tools
}
@@ -64,8 +62,6 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall)
case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content

View File

@@ -1035,92 +1035,6 @@ func TestQwenToolCallValueParsing(t *testing.T) {
}
}
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}
func TestQwenXMLTransform(t *testing.T) {
cases := []struct {
desc string

View File

@@ -180,22 +180,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
return events, false
}
case CollectingThinkingContent:
acc := p.buffer.String()
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
toolOpenIdx := strings.Index(acc, toolOpenTag)
// If a tool call starts before </think>, treat that as the end of thinking
// for parsing purposes and continue in tool-call mode.
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventThinkingContent{content: before})
}
p.state = CollectingToolContent
return events, true
}
if strings.Contains(acc, thinkingCloseTag) {
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
@@ -206,13 +191,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
p.state = CollectingContent
}
return events, true
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
@@ -220,11 +205,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
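
The overlap helper is not shown in the diff; reconstructing its contract from how the callers slice the buffer, it returns the length of the longest suffix of the accumulated text that is a prefix of the tag, i.e. how many trailing bytes might be the start of a split tag and must be held back. A sketch under that assumption:

```go
package main

import (
	"fmt"
	"strings"
)

// overlap returns the length of the longest suffix of s that is a prefix of
// tag (assumed contract; the real helper is defined elsewhere in the package).
func overlap(s, tag string) int {
	for n := min(len(s), len(tag)); n > 0; n-- {
		if strings.HasPrefix(tag, s[len(s)-n:]) {
			return n
		}
	}
	return 0
}

func main() {
	acc := "some thinking</thi"
	n := overlap(acc, "</think>")
	// Emit the unambiguous prefix now; keep the possible partial tag buffered.
	fmt.Printf("emit=%q buffer=%q\n", acc[:len(acc)-n], acc[len(acc)-n:])
}
```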

View File

@@ -98,12 +98,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
desc: "nested thinking and tool call (outside thinking, inside tool call)",
steps: []step{
{
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm nested tool call"},
qwenEventContent{content: "</think>"},
},
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
},
},
},
@@ -113,7 +109,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
qwenEventContent{content: "</tool_call>"},
},
},
},
@@ -124,8 +121,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
{
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
wantEvents: []qwenEvent{
qwenEventThinkingContent{content: "I'm thinking"},
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
qwenEventContent{content: "</tool_call>"},
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
qwenEventContent{content: "</think>"},
},

View File

@@ -71,10 +71,6 @@ type Model struct {
Template *template.Template
}
func (m *Model) IsMLX() bool {
return m.Config.ModelFormat == "safetensors"
}
// Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{}

View File

@@ -30,44 +30,42 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
lastMsgIdx := len(msgs) - 1
currMsgIdx := 0
if truncate {
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
ctxLen := len(s)
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
}
if ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
}
}
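
A self-contained sketch of the truncation loop in the hunk above: drop messages from the front until the rendered prompt fits, carrying forward any system messages from the skipped prefix, and always keeping at least the last message. tokenCount stands in for render-plus-tokenize.

```go
package main

import "fmt"

type message struct{ role, content string }

func truncate(msgs []message, numCtx int, tokenCount func([]message) int) []message {
	for i := range msgs {
		// Collect system messages from the portion about to be skipped.
		var system []message
		for _, m := range msgs[:i] {
			if m.role == "system" {
				system = append(system, m)
			}
		}
		candidate := append(append([]message{}, system...), msgs[i:]...)
		// Must always include at least the last message, even if it overflows.
		if tokenCount(candidate) <= numCtx || i == len(msgs)-1 {
			return candidate
		}
	}
	return msgs
}

func main() {
	msgs := []message{
		{"system", "be brief"}, {"user", "old question"},
		{"assistant", "old answer"}, {"user", "new question"},
	}
	count := func(ms []message) int {
		n := 0
		for _, m := range ms {
			n += len(m.content) // crude stand-in for tokenization
		}
		return n
	}
	fmt.Println(truncate(msgs, 30, count))
}
```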

View File

@@ -21,76 +21,33 @@ type quantizer struct {
progressFn func(n uint64)
}
const quantizationChunkElements uint64 = 4 * 1024 * 1024
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
quantize := q.from.Kind != q.to.Kind
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
if !quantize {
n, err := io.Copy(w, sr)
if q.progressFn != nil {
q.progressFn(q.from.Size())
}
q.progressFn(q.from.Size())
return n, err
}
if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 {
return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape)
data, err := io.ReadAll(sr)
if err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
}
fromType := fsggml.TensorType(q.from.Kind)
toType := fsggml.TensorType(q.to.Kind)
nPerRow := q.from.Shape[0]
totalElements := q.from.Elements()
if totalElements%nPerRow != 0 {
return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape)
if uint64(len(data)) < q.from.Size() {
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
}
inRowSize := fromType.RowSize(nPerRow)
if inRowSize == 0 {
return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType)
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
}
totalRows := totalElements / nPerRow
rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1))
chunkBuf := make([]byte, inRowSize*rowsPerChunk)
var written int64
for row := uint64(0); row < totalRows; {
chunkRows := min(rowsPerChunk, totalRows-row)
chunkBytes := inRowSize * chunkRows
data := chunkBuf[:chunkBytes]
if _, err := io.ReadFull(sr, data); err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
}
var f32s []float32
chunkElements := chunkRows * nPerRow
if fromType == fsggml.TensorTypeF32 {
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements)
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements)
}
quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows})
n, err := w.Write(quantized)
written += int64(n)
if err != nil {
return written, err
}
if n != len(quantized) {
return written, io.ErrShortWrite
}
if q.progressFn != nil {
q.progressFn(chunkBytes)
}
row += chunkRows
}
return written, nil
data = ggml.Quantize(newType, f32s, q.from.Shape)
n, err := w.Write(data)
q.progressFn(q.from.Size())
return int64(n), err
}
type quantizeState struct {
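
The chunking arithmetic in the streaming path above, isolated as a runnable sketch: process whole rows at a time, sized so each chunk holds roughly the 4Mi-element budget (row width and row count are illustrative).

```go
package main

import "fmt"

func main() {
	const chunkElements = 4 * 1024 * 1024 // quantizationChunkElements
	nPerRow := uint64(8192)               // row width from the tensor shape
	totalRows := uint64(1000)             // tensor rows

	// At least one row per chunk, even for very wide tensors.
	rowsPerChunk := max(chunkElements/nPerRow, 1)
	for row := uint64(0); row < totalRows; {
		chunkRows := min(rowsPerChunk, totalRows-row)
		// ...read chunkRows*rowSize bytes, convert to f32, quantize, write...
		row += chunkRows
	}
	fmt.Println("rows per chunk:", rowsPerChunk) // 512 for an 8192-wide tensor
}
```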

View File

@@ -484,8 +484,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -558,7 +557,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
PeakMemory: cr.PeakMemory,
},
Logprobs: toAPILogprobs(cr.Logprobs),
}
@@ -2218,9 +2216,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error", "error", err)
@@ -2317,7 +2312,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
PeakMemory: r.PeakMemory,
},
Logprobs: toAPILogprobs(r.Logprobs),
}

View File

@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
// Check for experimental safetensors LLM models
if pending.model.IsMLX() {
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
@@ -764,7 +764,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}

View File

@@ -78,12 +78,6 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix++
}
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if prefix == len(tokens) && prefix > 0 {
prefix--
}
if prefix < len(c.tokens) {
trim := len(c.tokens) - prefix
for _, kv := range c.caches {
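
A runnable condensation of the prefix-reuse logic above: count the shared prefix between the cached tokens and the new prompt, but, per the removed guard, keep at least one token to re-evaluate so the pipeline can seed generation from it.

```go
package main

import "fmt"

func findRemaining(cached, tokens []int32) []int32 {
	prefix := 0
	for prefix < len(cached) && prefix < len(tokens) && cached[prefix] == tokens[prefix] {
		prefix++
	}
	if prefix == len(tokens) && prefix > 0 {
		prefix-- // full match: still re-evaluate the last token
	}
	// (the real code also trims cache entries past prefix here)
	return tokens[prefix:]
}

func main() {
	cached := []int32{1, 2, 3, 4}
	fmt.Println(findRemaining(cached, []int32{1, 2, 3, 4}))    // [4]
	fmt.Println(findRemaining(cached, []int32{1, 2, 9, 9, 9})) // [9 9 9]
}
```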

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
@@ -18,10 +19,8 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
@@ -29,16 +28,15 @@ import (
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
port int
modelName string
contextLength atomic.Int64
memory atomic.Uint64
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
memory uint
done chan error
client *http.Client
lastErr string
lastErrLock sync.Mutex
mu sync.Mutex
cmd *exec.Cmd
}
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
@@ -193,20 +191,6 @@ type completionOpts struct {
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
PeakMemory uint64
Error *api.StatusError
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
@@ -266,25 +250,28 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw CompletionResponse
var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: raw.PromptEvalDuration,
PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration,
PeakMemory: raw.PeakMemory,
EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
@@ -297,7 +284,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
func (c *Client) ContextLength() int {
return int(c.contextLength.Load())
return math.MaxInt
}
// Detokenize implements llm.LlamaServer.
@@ -351,10 +338,9 @@ func (c *Client) Pid() int {
}
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint64
Status int
Progress int
Memory uint
}
// Ping implements llm.LlamaServer.
@@ -377,10 +363,7 @@ func (c *Client) Ping(ctx context.Context) error {
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength.Store(int64(status.ContextLength))
c.memory.Store(status.Memory)
c.memory = status.Memory
return nil
}
@@ -413,7 +396,7 @@ func (c *Client) currentMemory() uint64 {
if err := c.Ping(ctx); err != nil {
slog.Warn("failed to get current memory", "error", err)
}
return c.memory.Load()
return uint64(c.memory)
}
// MemorySize implements llm.LlamaServer.

View File

@@ -64,10 +64,6 @@ func PeakMemory() int {
return int(peak)
}
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{}
func (Memory) LogValue() slog.Value {

View File

@@ -20,7 +20,6 @@ type Model interface {
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert

View File

@@ -6,12 +6,9 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
@@ -47,35 +44,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory()
inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
now := time.Now()
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
if err := request.Ctx.Err(); err != nil {
return err
@@ -115,7 +93,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
var b bytes.Buffer
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
if err := request.Ctx.Err(); err != nil {
return err
@@ -124,8 +103,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
nextSample, nextLogprobs = step(sample)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptEvalDuration = time.Since(now)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
@@ -133,16 +113,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0
final.EvalCount = i
final.CompletionTokens = i
break
}
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
case request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}:
}
@@ -155,8 +137,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
final.CompletionTokensDuration = time.Since(now)
select {
case <-request.Ctx.Done():
return request.Ctx.Err()
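
A worked example of the context-length cap removed above: generation may not push the session past the model's context length, so max_tokens is clamped to whatever room remains after the prompt.

```go
package main

import "fmt"

func capMaxTokens(requested, contextLength, promptLen int) int {
	maxGenerate := contextLength - promptLen
	if requested <= 0 {
		return maxGenerate // unset: allow everything that fits
	}
	return min(requested, maxGenerate)
}

func main() {
	fmt.Println(capMaxTokens(0, 4096, 1000))    // 3096
	fmt.Println(capMaxTokens(8192, 4096, 1000)) // 3096 (clamped)
	fmt.Println(capMaxTokens(256, 4096, 1000))  // 256
}
```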

View File

@@ -4,15 +4,14 @@ package mlxrunner
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"strings"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -22,7 +21,7 @@ import (
type Request struct {
TextCompletionsRequest
Responses chan CompletionResponse
Responses chan Response
Pipeline func(Request) error
Ctx context.Context
@@ -44,12 +43,25 @@ type TextCompletionsRequest struct {
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache kvCache
contextLength int
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
cache kvCache
}
func (r *Runner) Load(modelName string) error {
@@ -78,7 +90,6 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
return nil
}
@@ -147,17 +158,6 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
}
close(request.Responses)

View File

@@ -51,10 +51,9 @@ func Execute(args []string) error {
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(statusResponse{
Status: 0,
Progress: 100,
ContextLength: runner.contextLength,
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
Status: 0,
Progress: 100,
Memory: uint(mlx.ActiveMemory() + mlx.CacheMemory()),
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -80,7 +79,7 @@ func Execute(args []string) error {
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)}
request := Request{Responses: make(chan Response)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
@@ -89,6 +88,9 @@ func Execute(args []string) error {
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(

View File

@@ -430,10 +430,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

View File

@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

View File

@@ -262,10 +262,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}

View File

@@ -279,10 +279,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}