Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
857cffd22a bugfix: fix crash bug in token cache logic
This change fixes a problem in the token cache logic to avoid panics caused by empty token arrays
by ensuring at least one token remains on full cache hits in the relevant function. The happens
if there is an exact match in the cache on subsequent generations.
2026-02-26 18:35:44 -08:00
47 changed files with 300 additions and 4026 deletions

View File

@@ -15,7 +15,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/internal/orderedmap" "github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -570,7 +569,6 @@ type DebugInfo struct {
type Metrics struct { type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -936,10 +934,6 @@ func (m *Metrics) Summary() {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
} }
if m.PeakMemory > 0 {
fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
}
if m.LoadDuration > 0 { if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
} }
@@ -963,14 +957,6 @@ func (m *Metrics) Summary() {
} }
} }
func formatPeakMemory(b uint64) string {
if b >= format.GibiByte {
return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
}
return format.HumanBytes2(b)
}
func (opts *Options) FromMap(m map[string]any) error { func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct

View File

@@ -74,7 +74,8 @@ type LlamaServer interface {
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
MemorySize() (total, vram uint64) VRAMSize() uint64 // Total VRAM across all GPUs
TotalSize() uint64
VRAMByGPU(id ml.DeviceID) uint64 VRAMByGPU(id ml.DeviceID) uint64
Pid() int Pid() int
GetPort() int GetPort() int
@@ -684,9 +685,8 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// Windows CUDA should not use mmap for best performance // Windows CUDA should not use mmap for best performance
// Linux with a model larger than free space, mmap leads to thrashing // Linux with a model larger than free space, mmap leads to thrashing
// For CPU loads we want the memory to be allocated, not FS cache // For CPU loads we want the memory to be allocated, not FS cache
totalSize, _ := s.MemorySize()
if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
(len(gpus) == 0 && s.options.UseMMap == nil) || (len(gpus) == 0 && s.options.UseMMap == nil) ||
(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) { (s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1453,12 +1453,10 @@ type ImageData struct {
} }
type CompletionRequest struct { type CompletionRequest struct {
Prompt string Prompt string
Format json.RawMessage Format json.RawMessage
Images []ImageData Images []ImageData
Options *api.Options Options *api.Options
Think *api.ThinkValue
ExplicitOptions map[string]struct{}
Grammar string // set before sending the request to the subprocess Grammar string // set before sending the request to the subprocess
Shift bool Shift bool
@@ -1520,7 +1518,6 @@ type CompletionResponse struct {
PromptEvalDuration time.Duration `json:"prompt_eval_duration"` PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"` EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"` EvalDuration time.Duration `json:"eval_duration"`
PeakMemory uint64 `json:"peak_memory,omitempty"`
// Logprobs contains log probability information if requested // Logprobs contains log probability information if requested
Logprobs []Logprob `json:"logprobs,omitempty"` Logprobs []Logprob `json:"logprobs,omitempty"`
@@ -1851,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil return nil
} }
func (s *llmServer) MemorySize() (total, vram uint64) { func (s *llmServer) VRAMSize() uint64 {
if s.mem == nil { if s.mem == nil {
return 0, 0 return 0
} }
var mem uint64
for _, g := range s.mem.GPUs { for _, g := range s.mem.GPUs {
vram += g.Size() mem += g.Size()
} }
total = s.mem.InputWeights + s.mem.CPU.Size() + vram
// Some elements are always on CPU. However, if we have allocated all layers // Some elements are always on CPU. However, if we have allocated all layers
// on the GPU then include the CPU components as well, to represent complete offloading. // on the GPU then include the CPU components as well, to represent complete offloading.
noCPULayers := true noCPULayers := true
@@ -1872,11 +1869,25 @@ func (s *llmServer) MemorySize() (total, vram uint64) {
} }
} }
if noCPULayers { if noCPULayers {
vram += s.mem.InputWeights mem += s.mem.InputWeights
vram += s.mem.CPU.Graph mem += s.mem.CPU.Graph
} }
return total, vram return mem
}
func (s *llmServer) TotalSize() uint64 {
if s.mem == nil {
return 0
}
mem := s.mem.InputWeights
mem += s.mem.CPU.Size()
for _, g := range s.mem.GPUs {
mem += g.Size()
}
return mem
} }
func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 { func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {

View File

@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35) SSMBeta *nn.Linear `gguf:"ssm_beta"` // -> beta (qwen35)
SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35) SSMAlpha *nn.Linear `gguf:"ssm_alpha"` // -> alpha (qwen35)
SSMConv1D *convKernel `gguf:"ssm_conv1d"` SSMConv1D *convKernel `gguf:"ssm_conv1d"`
SSMDT ml.Tensor `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp() SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
SSMNorm *nn.RMSNorm `gguf:"ssm_norm"` SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
SSMOut *nn.Linear `gguf:"ssm_out"` SSMOut *nn.Linear `gguf:"ssm_out"`
@@ -135,18 +135,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
default: default:
return nil, errors.New("qwen3next: missing linear attention beta/alpha projections") return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
} }
if gdn.SSMDT == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
}
if gdn.SSMA == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
}
// Compute gate: softplus(alpha + dt_bias) * -A // Compute gate: softplus(alpha + dt_bias) * -A
alphaBiased := alpha.Add(ctx, gdn.SSMDT) alphaBiased := alpha.Add(ctx, gdn.SSMDT)

View File

@@ -437,46 +437,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return m.Output.Forward(ctx, hiddenStates), nil return m.Output.Forward(ctx, hiddenStates), nil
} }
func (m *Model) Validate() error {
if m.Options == nil {
return fmt.Errorf("qwen3next: missing model options")
}
if len(m.Layers) != len(m.Options.isRecurrent) {
return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
}
for i, layer := range m.Layers {
if !m.Options.isRecurrent[i] {
continue
}
gdn, ok := layer.Operator.(*GatedDeltaNet)
if !ok || gdn == nil {
return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
}
if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
}
if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
}
if gdn.SSMDT == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
}
if gdn.SSMA == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
}
if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
}
if gdn.SSMNorm == nil || gdn.SSMOut == nil {
return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
}
}
return nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil m.positionCache = nil
if len(m.mropeSections) > 0 { if len(m.mropeSections) > 0 {
@@ -490,64 +450,6 @@ var (
_ model.MultimodalProcessor = (*Model)(nil) _ model.MultimodalProcessor = (*Model)(nil)
) )
func defaultVHeadReordered(arch string) bool {
return arch == "qwen35" || arch == "qwen35moe"
}
func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
isRecurrent := make([]bool, numLayers)
hasZero := false
hasFull := false
for i := range numLayers {
if i >= len(headCountKV) {
continue
}
if headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else {
hasFull = true
}
}
if hasZero && hasFull {
return isRecurrent, nil
}
if !hasFull {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
}
// Compatibility path: older imports store a scalar KV head count and omit
// per-layer recurrent flags. Derive the hybrid layout from the interval.
interval := int(fullAttentionInterval)
if interval == 0 {
interval = min(4, numLayers)
}
if interval <= 0 {
return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
}
if interval > numLayers {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
}
hasZero = false
hasFull = false
for i := range numLayers {
isRecurrent[i] = (i+1)%interval != 0
if isRecurrent[i] {
hasZero = true
} else {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
}
return isRecurrent, nil
}
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
numLayers := int(c.Uint("block_count")) numLayers := int(c.Uint("block_count"))
layers := make([]Layer, numLayers) layers := make([]Layer, numLayers)
@@ -558,14 +460,26 @@ func New(c fs.Config) (model.Model, error) {
HeadCountKV() []uint64 HeadCountKV() []uint64
} }
var isRecurrent []bool
var headCountKV []uint64 var headCountKV []uint64
if hc, ok := c.(headCounts); ok { if hc, ok := c.(headCounts); ok {
headCountKV = hc.HeadCountKV() headCountKV = hc.HeadCountKV()
} }
isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval")) isRecurrent = make([]bool, numLayers)
if err != nil { hasZero := false
return nil, err hasFull := false
for i := range numLayers {
// If KV head count is 0, it's a recurrent layer
if i < len(headCountKV) && headCountKV[i] == 0 {
isRecurrent[i] = true
hasZero = true
} else if i < len(headCountKV) && headCountKV[i] > 0 {
hasFull = true
}
}
if !hasZero || !hasFull {
return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
} }
// Determine if MoE // Determine if MoE
@@ -629,7 +543,7 @@ func New(c fs.Config) (model.Model, error) {
ssmNGroup: int(c.Uint("ssm.group_count")), ssmNGroup: int(c.Uint("ssm.group_count")),
ssmDtRank: int(c.Uint("ssm.time_step_rank")), ssmDtRank: int(c.Uint("ssm.time_step_rank")),
convKernelSize: int(c.Uint("ssm.conv_kernel")), convKernelSize: int(c.Uint("ssm.conv_kernel")),
vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())), vHeadReordered: c.Bool("ssm.v_head_reordered", false),
isRecurrent: isRecurrent, isRecurrent: isRecurrent,
mropeSections: slices.Collect(func(yield func(int) bool) { mropeSections: slices.Collect(func(yield func(int) bool) {
for _, section := range mropeSections { for _, section := range mropeSections {
@@ -641,7 +555,7 @@ func New(c fs.Config) (model.Model, error) {
mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)), mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
} }
if opts.numKVHeads == 0 { if opts.numKVHeads == 0 {
return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value") return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
} }
// Calculate cache dimensions // Calculate cache dimensions

View File

@@ -1,65 +0,0 @@
package qwen3next
import (
"slices"
"strings"
"testing"
)
func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, false, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, true, false, true, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
if err != nil {
t.Fatalf("inferRecurrentLayers() error = %v", err)
}
want := []bool{true, true, false, true, true, false}
if !slices.Equal(got, want) {
t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
}
}
func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
if err == nil {
t.Fatal("inferRecurrentLayers() expected error, got nil")
}
if !strings.Contains(err.Error(), "must include at least one non-zero value") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestDefaultVHeadReordered(t *testing.T) {
if !defaultVHeadReordered("qwen35") {
t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
}
if !defaultVHeadReordered("qwen35moe") {
t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
}
if defaultVHeadReordered("qwen3next") {
t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
}
}

View File

@@ -1,45 +0,0 @@
package qwen3next
import (
"strings"
"testing"
"github.com/ollama/ollama/ml/nn"
)
func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
m := &Model{
Layers: []Layer{{
Operator: &GatedDeltaNet{
SSMQKV: &nn.Linear{},
SSMQKVGate: &nn.Linear{},
SSMBeta: &nn.Linear{},
SSMAlpha: &nn.Linear{},
},
}},
Options: &Options{
isRecurrent: []bool{true},
},
}
err := m.Validate()
if err == nil {
t.Fatal("Validate() expected error, got nil")
}
if !strings.Contains(err.Error(), "missing ssm_dt") {
t.Fatalf("unexpected error = %v", err)
}
}
func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
m := &Model{
Layers: []Layer{{Operator: &FullAttention{}}},
Options: &Options{
isRecurrent: []bool{false},
},
}
if err := m.Validate(); err != nil {
t.Fatalf("Validate() error = %v", err)
}
}

View File

@@ -32,10 +32,9 @@ const (
) )
type GLM46Parser struct { type GLM46Parser struct {
state glm46ParserState state glm46ParserState
buffer strings.Builder buffer strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
} }
func (p *GLM46Parser) HasToolSupport() bool { func (p *GLM46Parser) HasToolSupport() bool {
@@ -49,7 +48,6 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
// func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { // func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
return tools return tools
} }
@@ -91,8 +89,6 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("glm-4.6 tool call parsing failed", "error", err) slog.Warn("glm-4.6 tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent: case glm46EventThinkingContent:
thinkingSb.WriteString(event.content) thinkingSb.WriteString(event.content)

View File

@@ -11,7 +11,6 @@ type GLM47Parser struct {
func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
// When thinking is enabled (nil or true), the prompt ends with <think>, // When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag). // so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() { if thinkValue == nil || thinkValue.Bool() {

View File

@@ -97,91 +97,3 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
t.Fatalf("expected %#v, got %#v", expected, toolCall) t.Fatalf("expected %#v, got %#v", expected, toolCall)
} }
} }
func TestGLM47ParserToolCallIndexing(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
input := `plan</think>
<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
parser := GLM47Parser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -38,7 +38,6 @@ type Qwen3Parser struct {
state qwen3ParserState state qwen3ParserState
buffer strings.Builder buffer strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
hasThinkingSupport bool hasThinkingSupport bool
defaultThinking bool defaultThinking bool
maybeThinkingOpenAtBOL bool maybeThinkingOpenAtBOL bool
@@ -55,7 +54,6 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.buffer.Reset() p.buffer.Reset()
p.callIndex = 0
thinkingEnabled := thinkValue != nil && thinkValue.Bool() thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil { if thinkValue == nil {
@@ -108,8 +106,6 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
slog.Warn("qwen3 tool call parsing failed", "error", err) slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
calls = append(calls, toolCall) calls = append(calls, toolCall)
case qwen3EventThinkingContent: case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content) thinkingSb.WriteString(event.content)

View File

@@ -230,89 +230,3 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
t.Fatalf("expected no tool calls, got %d", len(calls)) t.Fatalf("expected no tool calls, got %d", len(calls))
} }
} }
func TestQwen3ParserToolCallIndexing(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
var all []api.ToolCall
_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}

View File

@@ -29,10 +29,9 @@ const (
) )
type Qwen3CoderParser struct { type Qwen3CoderParser struct {
state qwenParserState state qwenParserState
acc strings.Builder acc strings.Builder
tools []api.Tool tools []api.Tool
callIndex int
} }
func (p *Qwen3CoderParser) HasToolSupport() bool { func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -45,7 +44,6 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool { func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools p.tools = tools
p.callIndex = 0
return tools // Qwen doesn't modify tools return tools // Qwen doesn't modify tools
} }
@@ -64,8 +62,6 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
slog.Warn("qwen tool call parsing failed", "error", err) slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err return "", "", nil, err
} }
toolCall.Function.Index = p.callIndex
p.callIndex++
toolCalls = append(toolCalls, toolCall) toolCalls = append(toolCalls, toolCall)
case qwenEventContent: case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content // TODO(drifkin): if the same turn contains multiple interleaved content

View File

@@ -1035,92 +1035,6 @@ func TestQwenToolCallValueParsing(t *testing.T) {
} }
} }
func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
_, _, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(calls) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(calls))
}
for i := range want {
if !toolCallEqual(calls[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
var all []api.ToolCall
_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
if err != nil {
t.Fatalf("step 1 parse failed: %v", err)
}
all = append(all, calls...)
_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("step 2 parse failed: %v", err)
}
all = append(all, calls...)
want := []api.ToolCall{
{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
}
if len(all) != len(want) {
t.Fatalf("expected %d calls, got %d", len(want), len(all))
}
for i := range want {
if !toolCallEqual(all[i], want[i]) {
t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
}
}
}
func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
parser := Qwen3CoderParser{}
parser.Init(nil, nil, nil)
_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("first parse failed: %v", err)
}
parser.Init(nil, nil, nil)
_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
if err != nil {
t.Fatalf("second parse failed: %v", err)
}
want := api.ToolCall{
Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
}
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if !toolCallEqual(calls[0], want) {
t.Fatalf("got %#v, want %#v", calls[0], want)
}
}
func TestQwenXMLTransform(t *testing.T) { func TestQwenXMLTransform(t *testing.T) {
cases := []struct { cases := []struct {
desc string desc string

View File

@@ -71,10 +71,6 @@ type Model struct {
Template *template.Template Template *template.Template
} }
func (m *Model) IsMLX() bool {
return m.Config.ModelFormat == "safetensors"
}
// Capabilities returns the capabilities that the model supports // Capabilities returns the capabilities that the model supports
func (m *Model) Capabilities() []model.Capability { func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{} capabilities := []model.Capability{}

View File

@@ -30,44 +30,42 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
lastMsgIdx := len(msgs) - 1 lastMsgIdx := len(msgs) - 1
currMsgIdx := 0 currMsgIdx := 0
if truncate { // Start with all messages and remove from the front until it fits in context
// Start with all messages and remove from the front until it fits in context for i := 0; i <= lastMsgIdx; i++ {
for i := 0; i <= lastMsgIdx; i++ { // Collect system messages from the portion we're about to skip
// Collect system messages from the portion we're about to skip system = make([]api.Message, 0)
system = make([]api.Message, 0) for j := range i {
for j := range i { if msgs[j].Role == "system" {
if msgs[j].Role == "system" { system = append(system, msgs[j])
system = append(system, msgs[j])
}
} }
}
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think) p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
s, err := tokenize(ctx, p) s, err := tokenize(ctx, p)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
ctxLen := len(s) ctxLen := len(s)
if m.ProjectorPaths != nil { if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] { for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images) ctxLen += imageNumTokens * len(msg.Images)
}
} }
}
if ctxLen <= opts.NumCtx { if !truncate || ctxLen <= opts.NumCtx {
currMsgIdx = i currMsgIdx = i
break break
} }
// Must always include at least the last message // Must always include at least the last message
if i == lastMsgIdx { if i == lastMsgIdx {
currMsgIdx = lastMsgIdx currMsgIdx = lastMsgIdx
break break
}
} }
} }

View File

@@ -130,35 +130,6 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
return opts, nil return opts, nil
} }
func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
keys := []string{
"temperature",
"top_p",
"min_p",
"top_k",
"repeat_last_n",
"repeat_penalty",
"presence_penalty",
"frequency_penalty",
}
explicit := make(map[string]struct{}, len(keys))
for _, key := range keys {
if optionSpecified(modelOpts, requestOpts, key) {
explicit[key] = struct{}{}
}
}
return explicit
}
func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
if _, ok := requestOpts[key]; ok {
return true
}
_, ok := modelOpts[key]
return ok
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options. // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
@@ -513,8 +484,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// the real chat handler, but doing this as a stopgap to get renderer // the real chat handler, but doing this as a stopgap to get renderer
// support for generate // support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" { if values.Messages != nil && values.Suffix == "" && req.Template == "" {
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX() prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -568,16 +538,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
Think: req.Think, Shift: req.Shift == nil || *req.Shift,
ExplicitOptions: explicitOptions(m.Options, req.Options), Truncate: req.Truncate == nil || *req.Truncate,
Shift: req.Shift == nil || *req.Shift, Logprobs: req.Logprobs,
Truncate: req.Truncate == nil || *req.Truncate, TopLogprobs: req.TopLogprobs,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(cr llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
@@ -589,7 +557,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration, EvalDuration: cr.EvalDuration,
PeakMemory: cr.PeakMemory,
}, },
Logprobs: toAPILogprobs(cr.Logprobs), Logprobs: toAPILogprobs(cr.Logprobs),
} }
@@ -1984,9 +1951,6 @@ func (s *Server) PsHandler(c *gin.Context) {
} }
if v.llama != nil { if v.llama != nil {
mr.ContextLength = v.llama.ContextLength() mr.ContextLength = v.llama.ContextLength()
total, vram := v.llama.MemorySize()
mr.Size = int64(total)
mr.SizeVRAM = int64(vram)
} }
// The scheduler waits to set expiresAt, so if a model is loading it's // The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just // possible that it will be set to the unix epoch. For those cases, just
@@ -2249,9 +2213,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
truncate := req.Truncate == nil || *req.Truncate truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err) slog.Error("chat prompt error", "error", err)
@@ -2329,16 +2290,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
// sets up new context given parent context per request // sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{ err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: currentFormat, Format: currentFormat,
Options: opts, Options: opts,
Think: req.Think, Shift: req.Shift == nil || *req.Shift,
ExplicitOptions: explicitOptions(m.Options, req.Options), Truncate: truncate,
Shift: req.Shift == nil || *req.Shift, Logprobs: req.Logprobs,
Truncate: truncate, TopLogprobs: req.TopLogprobs,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
@@ -2350,7 +2309,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
PeakMemory: r.PeakMemory,
}, },
Logprobs: toAPILogprobs(r.Logprobs), Logprobs: toAPILogprobs(r.Logprobs),
} }

View File

@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
// Check for experimental safetensors LLM models // Check for experimental safetensors LLM models
if pending.model.IsMLX() { if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") { if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner // LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) { if s.loadMLX(pending) {
@@ -536,7 +536,6 @@ iGPUScan:
} }
} }
totalSize, vramSize := llama.MemorySize()
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
@@ -546,8 +545,8 @@ iGPUScan:
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
gpus: gpuIDs, gpus: gpuIDs,
discreteGPUs: discreteGPUs, discreteGPUs: discreteGPUs,
totalSize: totalSize, vramSize: llama.VRAMSize(),
vramSize: vramSize, totalSize: llama.TotalSize(),
loading: true, loading: true,
pid: llama.Pid(), pid: llama.Pid(),
} }
@@ -620,7 +619,6 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
sessionDuration = req.sessionDuration.Duration sessionDuration = req.sessionDuration.Duration
} }
totalSize, vramSize := server.MemorySize()
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
@@ -630,8 +628,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
loading: false, loading: false,
isImagegen: isImagegen, isImagegen: isImagegen,
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
totalSize: totalSize, totalSize: server.TotalSize(),
vramSize: vramSize, vramSize: server.VRAMSize(),
} }
s.loadedMu.Lock() s.loadedMu.Lock()
@@ -764,7 +762,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
defer cancel() defer cancel()
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed? if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed? !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed? !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil { runner.llama.Ping(ctx) != nil {
return true return true
} }

View File

@@ -861,7 +861,8 @@ func (s *mockLlm) Close() error {
s.closeCalled = true s.closeCalled = true
return s.closeResp return s.closeResp
} }
func (s *mockLlm) MemorySize() (uint64, uint64) { return s.totalSize, s.vramSize } func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] } func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
func (s *mockLlm) Pid() int { return -1 } func (s *mockLlm) Pid() int { return -1 }
func (s *mockLlm) GetPort() int { return -1 } func (s *mockLlm) GetPort() int { return -1 }

View File

@@ -288,18 +288,6 @@ func normalizeQuantType(quantize string) string {
} }
} }
func isStackedExpertWeight(name string) bool {
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
// or "...proj" (pre-stacked packed tensor).
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
return false
}
return strings.Contains(name, ".mlp.switch_mlp.") ||
strings.Contains(name, ".mlp.experts.") ||
strings.Contains(name, ".mlp.shared_experts.")
}
// GetTensorQuantization returns the appropriate quantization type for a tensor. // GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized. // Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization: // This implements mixed-precision quantization:
@@ -308,25 +296,18 @@ func isStackedExpertWeight(name string) bool {
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel) // - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization // - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string { func GetTensorQuantization(name string, shape []int32, quantize string) string {
stackedExpert := isStackedExpertWeight(name)
// Use basic name-based check first // Use basic name-based check first
if !stackedExpert && !ShouldQuantize(name, "") { if !ShouldQuantize(name, "") {
return "" return ""
} }
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D), // Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
// e.g. qwen switch_mlp / experts combined tensors. if len(shape) != 2 {
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
return "" return ""
} }
// Skip small tensors (less than 1024 elements) - not worth quantizing // Skip small tensors (less than 1024 elements) - not worth quantizing
var elems int64 = 1 if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
for _, d := range shape {
elems *= int64(d)
}
if elems < 1024 {
return "" return ""
} }

View File

@@ -557,10 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
// 3D+ tensors should not be quantized // 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false}, {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false}, {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
// Embeddings should not be quantized regardless of shape // Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false}, {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -623,44 +619,6 @@ func TestExpertGroupPrefix(t *testing.T) {
} }
} }
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
gateUp := GetTensorQuantization(
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
[]int32{64, 22016, 4096},
"int4",
)
if gateUp != "int4" {
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
}
down := GetTensorQuantization(
"model.layers.1.mlp.experts.down_proj.weight",
[]int32{64, 4096, 14336},
"int4",
)
if down != "int8" {
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
}
combinedGateUp := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.gate_up_proj",
[]int32{256, 1024, 2048},
"int8",
)
if combinedGateUp != "int8" {
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
}
combinedDown := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.down_proj",
[]int32{256, 2048, 512},
"int4",
)
if combinedDown != "int8" {
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) { func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()

View File

@@ -374,9 +374,14 @@ func (s *Server) Close() error {
return nil return nil
} }
// MemorySize returns the total and VRAM memory usage. // VRAMSize returns the estimated VRAM usage.
func (s *Server) MemorySize() (total, vram uint64) { func (s *Server) VRAMSize() uint64 {
return s.vramSize, s.vramSize return s.vramSize
}
// TotalSize returns the total memory usage.
func (s *Server) TotalSize() uint64 {
return s.vramSize
} }
// VRAMByGPU returns VRAM usage for a specific GPU. // VRAMByGPU returns VRAM usage for a specific GPU.

View File

@@ -30,64 +30,21 @@ type cacheSession struct {
remaining []int32 remaining []int32
} }
func (c *kvCache) free() {
for i, kv := range c.caches {
if kv == nil {
continue
}
kv.Free()
c.caches[i] = nil
}
c.caches = nil
c.tokens = nil
}
func (c *kvCache) cachesCanTrim() bool {
for _, kv := range c.caches {
if kv == nil {
continue
}
if !kv.CanTrim() {
return false
}
}
return true
}
func (c *kvCache) trimToPrefix(prefix int) {
for _, kv := range c.caches {
if kv == nil || !kv.CanTrim() {
continue
}
if trim := kv.Offset() - prefix; trim > 0 {
kv.Trim(trim)
}
}
if prefix < len(c.tokens) {
c.tokens = c.tokens[:prefix]
}
}
// begin prepares caches for a new request. It finds the nearest // begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match. // matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
ensureCaches := func() { if len(c.caches) == 0 {
if len(c.caches) != 0 {
return
}
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches() c.caches = cacheFactory.NewCaches()
return } else {
} c.caches = make([]cache.Cache, m.NumLayers())
c.caches = make([]cache.Cache, m.NumLayers()) for i := range c.caches {
for i := range c.caches { c.caches[i] = cache.NewKVCache()
c.caches[i] = cache.NewKVCache() }
} }
} }
ensureCaches()
remaining := c.findRemaining(inputs) remaining := c.findRemaining(inputs)
ensureCaches()
return &cacheSession{ return &cacheSession{
cache: c, cache: c,
@@ -99,34 +56,18 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// close saves the token state if the forward pass ran. // close saves the token state if the forward pass ran.
func (s *cacheSession) close() { func (s *cacheSession) close() {
if len(s.caches) == 0 { if offset := s.caches[0].Offset(); offset > 0 {
return // Ensure that if we have run the forward pass and set the metadata
} // that we also actually have the data
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
offset := -1 for _, c := range s.caches {
arrays := make([]*mlx.Array, 0, 2*len(s.caches)) k, v := c.State()
for _, kv := range s.caches { arrays = append(arrays, k, v)
if kv == nil {
continue
} }
if off := kv.Offset(); offset < 0 || off < offset { mlx.AsyncEval(arrays...)
offset = off
}
arrays = append(arrays, kv.Materialize()...)
}
if offset <= 0 {
return
}
// Ensure that if we have run the forward pass and set the metadata s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
// that we also actually have the data.
mlx.AsyncEval(arrays...)
stored := append(s.inputs, s.outputs...)
if offset > len(stored) {
offset = len(stored)
} }
s.cache.tokens = stored[:offset]
} }
// findRemaining finds the longest common prefix between tokens and the cached // findRemaining finds the longest common prefix between tokens and the cached
@@ -137,20 +78,17 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix++ prefix++
} }
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if prefix == len(tokens) && prefix > 0 { if prefix == len(tokens) && prefix > 0 {
// Leave one token to run through the model so we can sample a response.
prefix-- prefix--
} }
if prefix < len(c.tokens) { if prefix < len(c.tokens) {
if c.cachesCanTrim() { trim := len(c.tokens) - prefix
c.trimToPrefix(prefix) for _, kv := range c.caches {
} else { kv.Trim(trim)
c.free()
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
return tokens
} }
c.tokens = c.tokens[:prefix]
} }
if prefix == 0 { if prefix == 0 {
@@ -165,21 +103,10 @@ func (c *kvCache) log() {
if len(c.caches) == 0 { if len(c.caches) == 0 {
return return
} }
offset := -1
var totalBytes int var totalBytes int
for _, kv := range c.caches { for _, kv := range c.caches {
if kv == nil { k, v := kv.State()
continue totalBytes += k.NumBytes() + v.NumBytes()
}
if off := kv.Offset(); offset < 0 || off < offset {
offset = off
}
for _, a := range kv.Materialize() {
totalBytes += a.NumBytes()
}
} }
if offset < 0 { logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
return
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
} }

View File

@@ -10,8 +10,6 @@ import (
type Cache interface { type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array) Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array) State() (keys, values *mlx.Array)
Materialize() []*mlx.Array
CanTrim() bool
Trim(int) int Trim(int) int
Clone() Cache Clone() Cache
Free() Free()
@@ -69,20 +67,6 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
} }
// Materialize returns the backing key/value buffers currently held by the cache.
func (c *KVCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.keys != nil && c.keys.Valid() {
out = append(out, c.keys)
}
if c.values != nil && c.values.Valid() {
out = append(out, c.values)
}
return out
}
func (c *KVCache) CanTrim() bool { return true }
func (c *KVCache) Trim(n int) int { func (c *KVCache) Trim(n int) int {
n = min(c.offset, n) n = min(c.offset, n)
c.offset -= n c.offset -= n
@@ -206,8 +190,6 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
return c.keys, c.values return c.keys, c.values
} }
func (c *RotatingKVCache) CanTrim() bool { return true }
func (c *RotatingKVCache) Trim(n int) int { func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n) n = min(c.offset, n)
c.offset -= n c.offset -= n

View File

@@ -1,220 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
}
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
// Break dependency chains so recurrent state does not retain the full
// per-token compute graph over time.
snap := mlx.Snapshot(v)
mlx.Eval(snap)
old := *dst
*dst = snap
mlx.Pin(snap)
// Drop references to the previous cached state root and transient incoming
// graph root now that a detached snapshot is retained in cache. Actual
// cleanup happens at the runner's normal sweep points.
if old != nil && old != snap {
mlx.Unpin(old)
}
if v != snap && v != old {
mlx.Unpin(v)
}
}
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
old := *dst
*dst = v
mlx.Pin(v)
if old != nil && old != v {
mlx.Unpin(old)
}
}
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := mlx.Detach(root)
old := *dst
*dst = detached
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Unpin(old)
}
// Intentionally do not force-release root/v here. In the fast path, the detached
// handle aliases the same MLX value and may still be lazily computed. Releasing the
// source handles can invalidate the cached state before the next eval/sweep point.
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Snapshot(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
if needConv {
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
c.setStateMaterialized(&c.convState, v)
}
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point. The conv-state input is usually a slice view, so request a
// compact contiguous copy to avoid pinning the whole source buffer.
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
c.setStateDetached(&c.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
c.setStateMaterialized(&c.deltaState, v)
}
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point.
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
c.setStateDetached(&c.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
return c.convState, c.deltaState
}
// Materialize returns the recurrent state roots (conv and delta) held by the cache.
func (c *RecurrentCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.convState != nil && c.convState.Valid() {
out = append(out, c.convState)
}
if c.deltaState != nil && c.deltaState.Valid() {
out = append(out, c.deltaState)
}
return out
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
convState: snapshotPinned(c.convState),
deltaState: snapshotPinned(c.deltaState),
}
return clone
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
@@ -18,27 +19,25 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
) )
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models. // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct { type Client struct {
port int port int
modelName string modelName string
contextLength atomic.Int64 vramSize uint64
memory atomic.Uint64 done chan error
done chan error client *http.Client
client *http.Client lastErr string
lastErr string lastErrLock sync.Mutex
lastErrLock sync.Mutex mu sync.Mutex
mu sync.Mutex cmd *exec.Cmd
cmd *exec.Cmd
} }
// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready. // NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
@@ -99,9 +98,18 @@ func NewClient(modelName string) (*Client, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal) slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
} }
// Estimate VRAM based on tensor size from manifest
var vramSize uint64
if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
vramSize = uint64(modelManifest.TotalTensorSize())
} else {
vramSize = 8 * 1024 * 1024 * 1024
}
c := &Client{ c := &Client{
port: port, port: port,
modelName: modelName, modelName: modelName,
vramSize: vramSize,
done: make(chan error, 1), done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute}, client: &http.Client{Timeout: 10 * time.Minute},
cmd: cmd, cmd: cmd,
@@ -182,34 +190,15 @@ func (c *Client) waitUntilRunning() error {
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. // completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct { type completionRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options *completionOpts `json:"options,omitempty"` Options *completionOpts `json:"options,omitempty"`
} }
type completionOpts struct { type completionOpts struct {
Temperature *float32 `json:"temperature,omitempty"` Temperature float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"` TopP float32 `json:"top_p,omitempty"`
MinP *float32 `json:"min_p,omitempty"` MinP float32 `json:"min_p,omitempty"`
TopK *int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
RepeatLastN *int `json:"repeat_last_n,omitempty"` NumPredict int `json:"num_predict,omitempty"`
RepeatPenalty *float32 `json:"repeat_penalty,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
PeakMemory uint64
Error *api.StatusError
} }
// Close terminates the subprocess. // Close terminates the subprocess.
@@ -233,27 +222,16 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer. // Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var think *bool
if req.Think != nil {
enabled := req.Think.Bool()
think = &enabled
}
creq := completionRequest{ creq := completionRequest{
Prompt: req.Prompt, Prompt: req.Prompt,
Think: think,
} }
if req.Options != nil { if req.Options != nil {
creq.Options = &completionOpts{ creq.Options = &completionOpts{
Temperature: float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")), Temperature: req.Options.Temperature,
TopP: float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")), TopP: req.Options.TopP,
MinP: float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")), MinP: req.Options.MinP,
TopK: intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")), TopK: req.Options.TopK,
RepeatLastN: intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")), NumPredict: req.Options.NumPredict,
RepeatPenalty: float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
PresencePenalty: float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
NumPredict: req.Options.NumPredict,
} }
} }
@@ -282,25 +260,28 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
var raw CompletionResponse var raw struct {
Content string `json:"content,omitempty"`
Done bool `json:"done"`
DoneReason int `json:"done_reason,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil { if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes())) slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue continue
} }
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{ cresp := llm.CompletionResponse{
Content: raw.Content, Content: raw.Content,
Done: raw.Done, Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason), DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount, PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: raw.PromptEvalDuration, PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration, EvalDuration: time.Duration(raw.EvalDuration),
PeakMemory: raw.PeakMemory,
} }
fn(cresp) fn(cresp)
@@ -312,27 +293,8 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return scanner.Err() return scanner.Err()
} }
func hasExplicitOption(explicit map[string]struct{}, key string) bool {
_, ok := explicit[key]
return ok
}
func float32Ptr(v float32, ok bool) *float32 {
if !ok {
return nil
}
return &v
}
func intPtr(v int, ok bool) *int {
if !ok {
return nil
}
return &v
}
func (c *Client) ContextLength() int { func (c *Client) ContextLength() int {
return int(c.contextLength.Load()) return math.MaxInt
} }
// Detokenize implements llm.LlamaServer. // Detokenize implements llm.LlamaServer.
@@ -385,16 +347,9 @@ func (c *Client) Pid() int {
return -1 return -1
} }
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint64
}
// Ping implements llm.LlamaServer. // Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error { func (c *Client) Ping(ctx context.Context) error {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port) reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil { if err != nil {
return err return err
@@ -407,15 +362,6 @@ func (c *Client) Ping(ctx context.Context) error {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode) return fmt.Errorf("health check failed: %d", resp.StatusCode)
} }
var status statusResponse
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength.Store(int64(status.ContextLength))
c.memory.Store(status.Memory)
return nil return nil
} }
@@ -442,24 +388,19 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
return tokens, nil return tokens, nil
} }
func (c *Client) currentMemory() uint64 { // TotalSize implements llm.LlamaServer.
ctx, cancel := context.WithTimeout(context.Background(), time.Second) func (c *Client) TotalSize() uint64 {
defer cancel() return c.vramSize
if err := c.Ping(ctx); err != nil {
slog.Warn("failed to get current memory", "error", err)
}
return c.memory.Load()
}
// MemorySize implements llm.LlamaServer.
func (c *Client) MemorySize() (total, vram uint64) {
mem := c.currentMemory()
return mem, mem
} }
// VRAMByGPU implements llm.LlamaServer. // VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.currentMemory() return c.vramSize
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
return c.vramSize
} }
// WaitUntilRunning implements llm.LlamaServer. // WaitUntilRunning implements llm.LlamaServer.

View File

@@ -1,167 +0,0 @@
package mlxrunner
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestCompletionForwardsThink(t *testing.T) {
boolPtr := func(v bool) *bool { return &v }
testCases := []struct {
name string
think *api.ThinkValue
want *bool
}{
{name: "unset", think: nil, want: nil},
{name: "enabled", think: &api.ThinkValue{Value: true}, want: boolPtr(true)},
{name: "disabled", think: &api.ThinkValue{Value: false}, want: boolPtr(false)},
{name: "level maps to enabled", think: &api.ThinkValue{Value: "high"}, want: boolPtr(true)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/completion" {
t.Fatalf("request path = %q, want %q", r.URL.Path, "/completion")
}
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Think: tc.think,
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Prompt != "hello" {
t.Fatalf("prompt = %q, want %q", got.Prompt, "hello")
}
switch {
case tc.want == nil && got.Think != nil:
t.Fatalf("think = %v, want nil", *got.Think)
case tc.want != nil && got.Think == nil:
t.Fatalf("think = nil, want %v", *tc.want)
case tc.want != nil && got.Think != nil && *tc.want != *got.Think:
t.Fatalf("think = %v, want %v", *got.Think, *tc.want)
}
})
}
}
func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
opts := &api.Options{
Temperature: 1.0,
TopP: 0.95,
MinP: 0.1,
TopK: 20,
RepeatLastN: 128,
RepeatPenalty: 1.2,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
NumPredict: 64,
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Options: opts,
ExplicitOptions: map[string]struct{}{
"temperature": {},
"top_k": {},
"repeat_penalty": {},
"presence_penalty": {},
},
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Options == nil {
t.Fatal("options = nil, want serialized options")
}
if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
}
if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
}
if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
}
if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
}
if got.Options.TopP != nil {
t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
}
if got.Options.MinP != nil {
t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
}
if got.Options.RepeatLastN != nil {
t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
}
if got.Options.FrequencyPenalty != nil {
t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
}
if got.Options.NumPredict != opts.NumPredict {
t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

View File

@@ -7,6 +7,4 @@ import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama" _ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3" _ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
) )

View File

@@ -1,275 +0,0 @@
//go:build mlx
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"sync/atomic"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled atomic.Bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}

View File

@@ -64,10 +64,6 @@ func PeakMemory() int {
return int(peak) return int(peak)
} }
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{} type Memory struct{}
func (Memory) LogValue() slog.Value { func (Memory) LogValue() slog.Value {

View File

@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
defer C.mlx_vector_array_free(vector) defer C.mlx_vector_array_free(vector)
for _, output := range outputs { for _, output := range outputs {
if output != nil && output.Valid() { if output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx) C.mlx_vector_array_append_value(vector, output.ctx)
} }
} }

View File

@@ -93,12 +93,6 @@ func (t *Array) Divide(other *Array) *Array {
return out return out
} }
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM")
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array { func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS") out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
@@ -129,30 +123,12 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out return out
} }
func (t *Array) GreaterEqual(other *Array) *Array {
out := New("GREATER_EQUAL")
C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array { func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP") out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out return out
} }
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LogicalOr(other *Array) *Array {
out := New("LOGICAL_OR")
C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array { func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL") out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View File

@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
return out return out
} }
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Convenience wrappers (function-style for the model code) // Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array { func Stack(arrays []*Array, axis int) *Array {
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
return a.Sigmoid() return a.Sigmoid()
} }
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array { func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("") mask := New("")
sinks := New("") sinks := New("")
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
func RMSNormFn(x, weight *Array, eps float32) *Array { func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM") out := New("FAST_RMSNORM")
var w C.mlx_array C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
return out return out
} }
@@ -429,27 +378,6 @@ func Collect(v any) []*Array {
return arrays return arrays
} }
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
func Snapshot(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("SNAPSHOT")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Detach returns a new Array handle that shares the same MLX value but does
// not retain Go-side graph input references.
func Detach(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("DETACH")
C.mlx_array_set(&out.ctx, a.ctx)
return out
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() { if !v.IsValid() {
return return

View File

@@ -20,7 +20,6 @@ type Model interface {
Unembed(x *mlx.Array) *mlx.Array Unembed(x *mlx.Array) *mlx.Array
NumLayers() int NumLayers() int
Tokenizer() *tokenizer.Tokenizer Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns // LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert // them to model fields. Model-specific logic (MLA absorption, expert

View File

@@ -6,30 +6,18 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"net/http"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error { func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil { if r.Model == nil {
return errors.New("model not loaded") return errors.New("model not loaded")
} }
ctx := request.Ctx
if ctx == nil {
ctx = context.Background()
}
var ( var (
sample, logprobs *mlx.Array sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array nextSample, nextLogprobs *mlx.Array
@@ -56,72 +44,43 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} else { } else {
mlx.DisableCompile() mlx.DisableCompile()
} }
mlx.ResetPeakMemory()
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs := r.Tokenizer.Encode(request.Prompt, true)
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
defer session.close() defer session.close()
caches := session.caches caches := session.caches
tokens := session.remaining tokens := session.remaining
history := append([]int32(nil), session.inputs...)
prefillChunk := prefillChunkSize()
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
if c == nil {
continue
}
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
now := time.Now()
total, processed := len(tokens), 0 total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 { for total-processed > 1 {
if err := ctx.Err(); err != nil { if err := request.Ctx.Err(); err != nil {
return err return err
} }
n := min(prefillChunk, total-processed-1) n := min(2<<10, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep() mlx.Sweep()
materializeCaches() mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total) slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache() mlx.ClearCache()
} }
step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) { step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true)) logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sample(logprobs, history) sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs) mlx.Pin(sample, logprobs)
mlx.Sweep() mlx.Sweep()
@@ -130,42 +89,45 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return sample, logprobs return sample, logprobs
} }
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history) sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer var b bytes.Buffer
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1} now := time.Now()
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.MaxTokens {
if err := ctx.Err(); err != nil { if err := request.Ctx.Err(); err != nil {
return err return err
} }
nextSample, nextLogprobs = step(sample)
if i == 0 { if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample) mlx.Eval(sample)
final.PromptEvalDuration = time.Since(now) final.PromptTokensDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
output := int32(sample.Int()) output := int32(sample.Int())
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
history = append(history, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
final.Token = int(output)
final.DoneReason = 0 final.DoneReason = 0
final.EvalCount = i final.CompletionTokens = i
break break
} }
select { select {
case <-request.Ctx.Done(): case <-request.Ctx.Done():
return request.Ctx.Err() return request.Ctx.Err()
case request.Responses <- CompletionResponse{ case request.Responses <- Response{
Content: r.Decode(output, &b), Text: r.Decode(output, &b),
Token: int(output),
}: }:
} }
nextSample, nextLogprobs = step(sample, history)
mlx.Unpin(sample, logprobs) mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil nextSample, nextLogprobs = nil, nil
@@ -175,11 +137,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
final.EvalDuration = time.Since(now) final.CompletionTokensDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select { select {
case <-ctx.Done(): case <-request.Ctx.Done():
return ctx.Err() return request.Ctx.Err()
case request.Responses <- final: case request.Responses <- final:
return nil return nil
} }

View File

@@ -4,15 +4,14 @@ package mlxrunner
import ( import (
"context" "context"
"errors"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -22,7 +21,7 @@ import (
type Request struct { type Request struct {
TextCompletionsRequest TextCompletionsRequest
Responses chan CompletionResponse Responses chan Response
Pipeline func(Request) error Pipeline func(Request) error
Ctx context.Context Ctx context.Context
@@ -32,29 +31,37 @@ type Request struct {
type TextCompletionsRequest struct { type TextCompletionsRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options struct { Options struct {
Temperature *float32 `json:"temperature"` Temperature float32 `json:"temperature"`
TopP *float32 `json:"top_p"` TopP float32 `json:"top_p"`
MinP *float32 `json:"min_p"` MinP float32 `json:"min_p"`
TopK *int `json:"top_k"` TopK int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"` MaxTokens int `json:"max_tokens"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead // Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"` NumPredict int `json:"num_predict"`
} `json:"options"` } `json:"options"`
} }
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
Requests chan Request Requests chan Request
cache kvCache cache kvCache
contextLength int
} }
func (r *Runner) Load(modelName string) error { func (r *Runner) Load(modelName string) error {
@@ -83,7 +90,6 @@ func (r *Runner) Load(modelName string) error {
r.Model = m r.Model = m
r.Tokenizer = m.Tokenizer() r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
return nil return nil
} }
@@ -152,17 +158,6 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case request := <-r.Requests: case request := <-r.Requests:
if err := request.Pipeline(request); err != nil { if err := request.Pipeline(request); err != nil {
slog.Info("Request terminated", "error", err) slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
} }
close(request.Responses) close(request.Responses)

View File

@@ -9,204 +9,69 @@ import (
) )
type Sampler interface { type Sampler interface {
Sample(*mlx.Array, []int32) *mlx.Array Sample(*mlx.Array) *mlx.Array
} }
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler { func New(temp, top_p, min_p float32, top_k int) Sampler {
var samplers []Sampler if temp == 0 {
if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) { return greedy{}
samplers = append(samplers, Penalty{
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
})
} }
if temp == 0 { var samplers []Sampler
samplers = append(samplers, greedy{}) if top_p > 0 && top_p < 1 {
} else { samplers = append(samplers, TopP(top_p))
samplers = append(samplers, Distribution{
Temperature: temp,
TopK: top_k,
TopP: top_p,
MinP: min_p,
})
} }
if min_p != 0 {
samplers = append(samplers, MinP(min_p))
}
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers) return chain(samplers)
} }
type greedy struct{} type greedy struct{}
func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array { func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false) return logits.Argmax(-1, false)
} }
type chain []Sampler type chain []Sampler
func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array { func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c { for _, sampler := range c {
logits = sampler.Sample(logits, history) logits = sampler.Sample(logits)
} }
return logits return logits
} }
type Distribution struct { type Temperature float32
Temperature float32
TopK int func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
TopP float32 return mlx.DivScalar(logits, float32(t)).Categorical(-1)
MinP float32
} }
func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array { type TopP float32
filtered, indices := d.filter(logits)
sample := filtered.Categorical(-1)
if indices == nil {
return sample
}
positions := sample.ExpandDims(1) func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
return indices.TakeAlongAxis(positions, -1).Squeeze(1) // TODO: implement
return logprobs
} }
func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) { type MinP float32
candidates := logits
var candidateIndices *mlx.Array
if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) { func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1) // TODO: implement
switch logits.NumDims() { return logprobs
case 1:
candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
default:
candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
}
candidates = logits.TakeAlongAxis(candidateIndices, -1)
}
if d.Temperature != 1 {
candidates = mlx.DivScalar(candidates, d.Temperature)
}
if !d.needsProbabilityFilters() {
return candidates, candidateIndices
}
order := candidates.Negative().ArgsortAxis(-1)
sortedLogits := candidates.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
remove := d.topPRemovalMask(sortedProbs)
if d.MinP > 0 {
minPRemove := d.minPRemovalMask(sortedProbs)
if remove == nil {
remove = minPRemove
} else {
remove = remove.LogicalOr(minPRemove)
}
}
if remove == nil {
return candidates, candidateIndices
}
negInf := mlx.FromValue(float32(math.Inf(-1)))
filtered := mlx.Where(remove, negInf, sortedLogits)
return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
} }
func (d Distribution) needsProbabilityFilters() bool { type TopK int
return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
} func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array { return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
if d.TopP <= 0 || d.TopP >= 1 {
return nil
}
threshold := mlx.NewScalarArray(d.TopP)
prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
return prevCum.GreaterEqual(threshold)
}
func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
if d.MinP <= 0 {
return nil
}
var maxProb *mlx.Array
switch sortedProbs.NumDims() {
case 1:
maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
default:
maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
}
threshold := mlx.MulScalar(maxProb, d.MinP)
return sortedProbs.Less(threshold)
}
type Penalty struct {
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
}
func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
if len(history) == 0 {
return logprobs
}
window := p.RepeatLastN
if window <= 0 || window > len(history) {
window = len(history)
}
counts := make(map[int32]int, window)
order := make([]int32, 0, window)
for _, token := range history[len(history)-window:] {
if token < 0 {
continue
}
if counts[token] == 0 {
order = append(order, token)
}
counts[token]++
}
if len(order) == 0 {
return logprobs
}
indexShape := []int32{int32(len(order))}
valueShape := []int{len(order)}
if logprobs.NumDims() > 1 {
indexShape = []int32{1, int32(len(order))}
valueShape = []int{1, len(order)}
}
indices := mlx.NewArrayInt32(order, indexShape)
selected := logprobs.TakeAlongAxis(indices, -1)
mlx.Eval(selected)
values := selected.Floats()
for i, token := range order {
v := values[i]
if p.RepeatPenalty != 1 {
if v < 0 {
v *= p.RepeatPenalty
} else {
v /= p.RepeatPenalty
}
}
if p.PresencePenalty != 0 {
v -= p.PresencePenalty
}
if p.FrequencyPenalty != 0 {
v -= p.FrequencyPenalty * float32(counts[token])
}
values[i] = v
}
return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
} }

View File

@@ -1,104 +0,0 @@
//go:build mlx
package sample
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestPenaltySample(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, -2.0, 3.0, 4.0,
}, 1, 4)
got := Penalty{
RepeatLastN: 3,
RepeatPenalty: 2.0,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
}.Sample(logprobs, []int32{2, 1, 2})
mlx.Eval(got)
want := []float32{1.0, -5.75, -0.5, 4.0}
values := got.Floats()
if len(values) != len(want) {
t.Fatalf("len(values) = %d, want %d", len(values), len(want))
}
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, 2.0, 3.0,
}, 1, 3)
got := Penalty{
RepeatLastN: 1,
PresencePenalty: 1.0,
}.Sample(logprobs, []int32{0, 1})
mlx.Eval(got)
want := []float32{1.0, 1.0, 3.0}
values := got.Floats()
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestDistributionFilterTopP(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logits := mlx.FromValues([]float32{
10.0, 9.0, 1.0, 0.0,
}, 1, 4)
filtered, indices := Distribution{
Temperature: 1.0,
TopK: 2,
TopP: 0.55,
}.filter(logits)
got := materializeFilteredLogits(filtered, indices, 4)
mlx.Eval(got)
values := got.Floats()
if values[0] != 10.0 {
t.Fatalf("values[0] = %v, want 10", values[0])
}
for i := 1; i < len(values); i++ {
if !math.IsInf(float64(values[i]), -1) {
t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
}
}
}
func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
if indices == nil {
return filtered
}
base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
return base.PutAlongAxis(indices, filtered, -1)
}

View File

@@ -16,89 +16,12 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample" "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/models/qwen3_5"
) )
type samplingConfig struct {
temperature float32
topP float32
minP float32
topK int
repeatLastN int
repeatPenalty float32
presencePenalty float32
frequencyPenalty float32
}
func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
if _, ok := m.(*qwen3_5.Model); ok {
cfg := samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
}
if think != nil && !*think {
cfg.temperature = 0.7
cfg.topP = 0.8
}
return cfg
}
opts := api.DefaultOptions()
return samplingConfig{
temperature: opts.Temperature,
topP: opts.TopP,
minP: opts.MinP,
topK: opts.TopK,
repeatLastN: opts.RepeatLastN,
repeatPenalty: opts.RepeatPenalty,
presencePenalty: opts.PresencePenalty,
frequencyPenalty: opts.FrequencyPenalty,
}
}
func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
cfg := defaultSamplingConfig(m, req.Think)
if req.Options.Temperature != nil {
cfg.temperature = *req.Options.Temperature
}
if req.Options.TopP != nil {
cfg.topP = *req.Options.TopP
}
if req.Options.MinP != nil {
cfg.minP = *req.Options.MinP
}
if req.Options.TopK != nil {
cfg.topK = *req.Options.TopK
}
if req.Options.RepeatLastN != nil {
cfg.repeatLastN = *req.Options.RepeatLastN
}
if req.Options.RepeatPenalty != nil {
cfg.repeatPenalty = *req.Options.RepeatPenalty
}
if req.Options.PresencePenalty != nil {
cfg.presencePenalty = *req.Options.PresencePenalty
}
if req.Options.FrequencyPenalty != nil {
cfg.frequencyPenalty = *req.Options.FrequencyPenalty
}
return cfg
}
func Execute(args []string) error { func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
@@ -127,11 +50,9 @@ func Execute(args []string) error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(statusResponse{ if err := json.NewEncoder(w).Encode(map[string]any{
Status: 0, "status": 0,
Progress: 100, "progress": 100,
ContextLength: runner.contextLength,
Memory: uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
}); err != nil { }); err != nil {
slog.Error("Failed to encode response", "error", err) slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -157,7 +78,7 @@ func Execute(args []string) error {
}) })
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)} request := Request{Responses: make(chan Response)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)
@@ -166,19 +87,16 @@ func Execute(args []string) error {
} }
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict) request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
sampling := resolveSamplingConfig(runner.Model, request) request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New( request.Sampler = sample.New(
sampling.temperature, request.Options.Temperature,
sampling.topP, request.Options.TopP,
sampling.minP, request.Options.MinP,
sampling.topK, request.Options.TopK,
sampling.repeatLastN,
sampling.repeatPenalty,
sampling.presencePenalty,
sampling.frequencyPenalty,
) )
var cancel context.CancelFunc var cancel context.CancelFunc

View File

@@ -1,172 +0,0 @@
//go:build mlx
package mlxrunner
import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
"github.com/ollama/ollama/x/tokenizer"
)
type stubModel struct{}
func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
func (stubModel) NumLayers() int { return 0 }
func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
func TestResolveSamplingConfigDefaults(t *testing.T) {
trueValue := true
falseValue := false
tests := []struct {
name string
model base.Model
req Request
want samplingConfig
}{
{
name: "generic model uses api defaults",
model: stubModel{},
req: Request{},
want: samplingConfig{
temperature: 0.8,
topP: 0.9,
minP: 0.0,
topK: 40,
repeatLastN: 64,
repeatPenalty: 1.1,
presencePenalty: 0.0,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 defaults to thinking profile when think unset",
model: &qwen3_5.Model{},
req: Request{},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking disabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
want: samplingConfig{
temperature: 0.7,
topP: 0.8,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking enabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
}
})
}
}
func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
trueValue := true
temperature := float32(0.4)
topP := float32(0.6)
minP := float32(0.05)
topK := 12
repeatLastN := 32
repeatPenalty := float32(1.1)
presencePenalty := float32(0.7)
frequencyPenalty := float32(0.2)
got := resolveSamplingConfig(stubModel{}, Request{
TextCompletionsRequest: TextCompletionsRequest{
Think: &trueValue,
Options: struct {
Temperature *float32 `json:"temperature"`
TopP *float32 `json:"top_p"`
MinP *float32 `json:"min_p"`
TopK *int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
NumPredict int `json:"num_predict"`
}{
Temperature: &temperature,
TopP: &topP,
MinP: &minP,
TopK: &topK,
RepeatLastN: &repeatLastN,
RepeatPenalty: &repeatPenalty,
PresencePenalty: &presencePenalty,
FrequencyPenalty: &frequencyPenalty,
},
},
})
want := samplingConfig{
temperature: temperature,
topP: topP,
minP: minP,
topK: topK,
repeatLastN: repeatLastN,
repeatPenalty: repeatPenalty,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
}
if got != want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
}
}
func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
want := api.DefaultOptions()
got := defaultSamplingConfig(stubModel{}, nil)
if got.temperature != want.Temperature ||
got.topP != want.TopP ||
got.minP != want.MinP ||
got.topK != want.TopK ||
got.repeatLastN != want.RepeatLastN ||
got.repeatPenalty != want.RepeatPenalty ||
got.presencePenalty != want.PresencePenalty ||
got.frequencyPenalty != want.FrequencyPenalty {
t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
}
}

View File

@@ -430,10 +430,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }

View File

@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
func (m *Model) NumLayers() int { return len(m.Layers) } func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length // MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) } func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
// VocabSize returns the vocabulary size // VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize } func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

View File

@@ -262,10 +262,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }

View File

@@ -15,40 +15,6 @@ type LinearLayer interface {
OutputDim() int32 OutputDim() int32
} }
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b // Linear applies an affine transformation: y = x @ W.T + b
type Linear struct { type Linear struct {
Weight *mlx.Array Weight *mlx.Array

View File

@@ -279,10 +279,6 @@ func (m *Model) NumLayers() int {
return len(m.Layers) return len(m.Layers)
} }
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok return m.tok
} }

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,166 +0,0 @@
//go:build mlx
package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestModelRuntimeDefaults(t *testing.T) {
m := &Model{}
if m.DisablePromptCache() {
t.Fatal("DisablePromptCache() = true, want false")
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}

View File

@@ -1,16 +0,0 @@
//go:build mlx
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}