mirror of
https://github.com/ollama/ollama.git
synced 2026-02-27 12:36:54 -05:00
Compare commits
3 Commits
pdevine/sa
...
pdevine/me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b9cf42086 | ||
|
|
83d0f3890e | ||
|
|
338bfa51a7 |
@@ -204,24 +204,6 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.maybeThinkingOpenAtBOL = false
|
||||
}
|
||||
|
||||
thinkingCloseIdx := strings.Index(acc, qwen3ThinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, qwen3ToolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwen3EventThinkingContent{content: before})
|
||||
}
|
||||
if after == "" {
|
||||
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||
} else {
|
||||
p.state = qwen3ParserStateCollectingToolContent
|
||||
}
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
@@ -233,7 +215,7 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||
p.state = qwen3ParserStateCollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := max(overlap(acc, qwen3ThinkingCloseTag), overlap(acc, qwen3ToolOpenTag)); overlapLen > 0 {
|
||||
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||
|
||||
@@ -146,68 +146,6 @@ func TestQwen3ParserToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithToolCallBeforeThinkingClose(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
input := "Let me think<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||
content, thinking, calls, err := parser.Add(input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed: %v", err)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "Let me think" {
|
||||
t.Fatalf("expected thinking %q, got %q", "Let me think", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen3ParserThinkingWithSplitToolOpenTag(t *testing.T) {
|
||||
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
content, thinking, calls, err := parser.Add("Let me think<tool_ca", false)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on first chunk: %v", err)
|
||||
}
|
||||
if content != "" || thinking != "Let me think" || len(calls) != 0 {
|
||||
t.Fatalf(
|
||||
"expected content=%q thinking=%q calls=%d, got content=%q thinking=%q calls=%d",
|
||||
"",
|
||||
"Let me think",
|
||||
0,
|
||||
content,
|
||||
thinking,
|
||||
len(calls),
|
||||
)
|
||||
}
|
||||
|
||||
content, thinking, calls, err = parser.Add("ll>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"SF\"}}</tool_call>", true)
|
||||
if err != nil {
|
||||
t.Fatalf("parse failed on second chunk: %v", err)
|
||||
}
|
||||
if content != "" {
|
||||
t.Fatalf("expected empty content, got %q", content)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("expected no additional thinking on second chunk, got %q", thinking)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "get_weather" {
|
||||
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQwen35ParserRespectsNoThink(t *testing.T) {
|
||||
parser := ParserForName("qwen3.5")
|
||||
if parser == nil {
|
||||
|
||||
@@ -180,22 +180,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
return events, false
|
||||
}
|
||||
case CollectingThinkingContent:
|
||||
acc := p.buffer.String()
|
||||
thinkingCloseIdx := strings.Index(acc, thinkingCloseTag)
|
||||
toolOpenIdx := strings.Index(acc, toolOpenTag)
|
||||
|
||||
// If a tool call starts before </think>, treat that as the end of thinking
|
||||
// for parsing purposes and continue in tool-call mode.
|
||||
if toolOpenIdx != -1 && (thinkingCloseIdx == -1 || toolOpenIdx < thinkingCloseIdx) {
|
||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: before})
|
||||
}
|
||||
p.state = CollectingToolContent
|
||||
return events, true
|
||||
}
|
||||
|
||||
if strings.Contains(acc, thinkingCloseTag) {
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||
@@ -206,13 +191,13 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
p.state = CollectingContent
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := max(overlap(acc, thinkingCloseTag), overlap(acc, toolOpenTag)); overlapLen > 0 {
|
||||
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
@@ -220,11 +205,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(acc)
|
||||
ambiguousStart := len(acc) - whitespaceLen
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := acc[:ambiguousStart]
|
||||
ambiguous := acc[ambiguousStart:]
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
|
||||
@@ -98,12 +98,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
desc: "nested thinking and tool call (outside thinking, inside tool call)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
input: "I'm thinking<tool_call>I'm nested tool call</tool_call></think>",
|
||||
wantEvents: []qwenEvent{qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm nested tool call</tool_call>"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -113,7 +109,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "<tool_call>I'm nested tool call<think>I'm thinking</think></tool_call>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call<think>I'm thinking</think>"},
|
||||
qwenEventThinkingContent{content: "<tool_call>I'm nested tool call<think>I'm thinking"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -124,8 +121,8 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
|
||||
{
|
||||
input: "I'm thinking<tool_call>I'm NOT a nested tool call</think></tool_call><tool_call>I'm nested tool call 2<think></tool_call></think>",
|
||||
wantEvents: []qwenEvent{
|
||||
qwenEventThinkingContent{content: "I'm thinking"},
|
||||
qwenEventRawToolCall{raw: "I'm NOT a nested tool call</think>"},
|
||||
qwenEventThinkingContent{content: "I'm thinking<tool_call>I'm NOT a nested tool call"},
|
||||
qwenEventContent{content: "</tool_call>"},
|
||||
qwenEventRawToolCall{raw: "I'm nested tool call 2<think>"},
|
||||
qwenEventContent{content: "</think>"},
|
||||
},
|
||||
|
||||
@@ -288,6 +288,18 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isStackedExpertWeight reports whether name looks like a combined (stacked)
// expert tensor, i.e. it sits under one of the switch_mlp / experts /
// shared_experts MLP sub-paths. Such tensors may be named either
// "...proj.weight" or just "...proj". Names ending in ".bias", ".scale" or
// ".qbias" are never treated as stacked expert weights.
func isStackedExpertWeight(name string) bool {
	// Reject the auxiliary tensors first so that the path check below cannot
	// match them.
	for _, suffix := range []string{".bias", ".scale", ".qbias"} {
		if strings.HasSuffix(name, suffix) {
			return false
		}
	}

	for _, marker := range []string{".mlp.switch_mlp.", ".mlp.experts.", ".mlp.shared_experts."} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
@@ -296,18 +308,25 @@ func normalizeQuantType(quantize string) string {
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
// Use basic name-based check first
|
||||
if !ShouldQuantize(name, "") {
|
||||
if !stackedExpert && !ShouldQuantize(name, "") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||
if len(shape) != 2 {
|
||||
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
||||
// e.g. qwen switch_mlp / experts combined tensors.
|
||||
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||
var elems int64 = 1
|
||||
for _, d := range shape {
|
||||
elems *= int64(d)
|
||||
}
|
||||
if elems < 1024 {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -557,6 +557,10 @@ func TestShouldQuantizeTensor(t *testing.T) {
|
||||
// 3D+ tensors should not be quantized
|
||||
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
|
||||
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
|
||||
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
|
||||
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
|
||||
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
|
||||
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
|
||||
|
||||
// Embeddings should not be quantized regardless of shape
|
||||
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
|
||||
@@ -619,6 +623,44 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
gateUp := GetTensorQuantization(
|
||||
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
|
||||
[]int32{64, 22016, 4096},
|
||||
"int4",
|
||||
)
|
||||
if gateUp != "int4" {
|
||||
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
|
||||
}
|
||||
|
||||
down := GetTensorQuantization(
|
||||
"model.layers.1.mlp.experts.down_proj.weight",
|
||||
[]int32{64, 4096, 14336},
|
||||
"int4",
|
||||
)
|
||||
if down != "int8" {
|
||||
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
|
||||
}
|
||||
|
||||
combinedGateUp := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.gate_up_proj",
|
||||
[]int32{256, 1024, 2048},
|
||||
"int8",
|
||||
)
|
||||
if combinedGateUp != "int8" {
|
||||
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
|
||||
}
|
||||
|
||||
combinedDown := GetTensorQuantization(
|
||||
"model.language_model.layers.0.mlp.experts.down_proj",
|
||||
[]int32{256, 2048, 512},
|
||||
"int4",
|
||||
)
|
||||
if combinedDown != "int8" {
|
||||
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
@@ -30,21 +30,64 @@ type cacheSession struct {
|
||||
remaining []int32
|
||||
}
|
||||
|
||||
func (c *kvCache) free() {
|
||||
for i, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
kv.Free()
|
||||
c.caches[i] = nil
|
||||
}
|
||||
c.caches = nil
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *kvCache) cachesCanTrim() bool {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.CanTrim() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *kvCache) trimToPrefix(prefix int) {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil || !kv.CanTrim() {
|
||||
continue
|
||||
}
|
||||
if trim := kv.Offset() - prefix; trim > 0 {
|
||||
kv.Trim(trim)
|
||||
}
|
||||
}
|
||||
if prefix < len(c.tokens) {
|
||||
c.tokens = c.tokens[:prefix]
|
||||
}
|
||||
}
|
||||
|
||||
// begin prepares caches for a new request. It finds the nearest
|
||||
// matching cache or creates new caches if none match.
|
||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
if len(c.caches) == 0 {
|
||||
ensureCaches := func() {
|
||||
if len(c.caches) != 0 {
|
||||
return
|
||||
}
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
} else {
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return
|
||||
}
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
ensureCaches()
|
||||
|
||||
remaining := c.findRemaining(inputs)
|
||||
ensureCaches()
|
||||
|
||||
return &cacheSession{
|
||||
cache: c,
|
||||
@@ -56,18 +99,34 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
|
||||
// close saves the token state if the forward pass ran.
|
||||
func (s *cacheSession) close() {
|
||||
if offset := s.caches[0].Offset(); offset > 0 {
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, c := range s.caches {
|
||||
k, v := c.State()
|
||||
arrays = append(arrays, k, v)
|
||||
}
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
|
||||
if len(s.caches) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
offset := -1
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, kv := range s.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
arrays = append(arrays, kv.Materialize()...)
|
||||
}
|
||||
if offset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data.
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
stored := append(s.inputs, s.outputs...)
|
||||
if offset > len(stored) {
|
||||
offset = len(stored)
|
||||
}
|
||||
s.cache.tokens = stored[:offset]
|
||||
}
|
||||
|
||||
// findRemaining finds the longest common prefix between tokens and the cached
|
||||
@@ -78,17 +137,14 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||
prefix++
|
||||
}
|
||||
|
||||
if prefix == len(tokens) && prefix > 0 {
|
||||
// Leave one token to run through the model so we can sample a response.
|
||||
prefix--
|
||||
}
|
||||
|
||||
if prefix < len(c.tokens) {
|
||||
trim := len(c.tokens) - prefix
|
||||
for _, kv := range c.caches {
|
||||
kv.Trim(trim)
|
||||
if c.cachesCanTrim() {
|
||||
c.trimToPrefix(prefix)
|
||||
} else {
|
||||
c.free()
|
||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||
return tokens
|
||||
}
|
||||
c.tokens = c.tokens[:prefix]
|
||||
}
|
||||
|
||||
if prefix == 0 {
|
||||
@@ -103,10 +159,21 @@ func (c *kvCache) log() {
|
||||
if len(c.caches) == 0 {
|
||||
return
|
||||
}
|
||||
offset := -1
|
||||
var totalBytes int
|
||||
for _, kv := range c.caches {
|
||||
k, v := kv.State()
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
for _, a := range kv.Materialize() {
|
||||
totalBytes += a.NumBytes()
|
||||
}
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
if offset < 0 {
|
||||
return
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
||||
}
|
||||
|
||||
18
x/mlxrunner/cache/cache.go
vendored
18
x/mlxrunner/cache/cache.go
vendored
@@ -10,6 +10,8 @@ import (
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
State() (keys, values *mlx.Array)
|
||||
Materialize() []*mlx.Array
|
||||
CanTrim() bool
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Free()
|
||||
@@ -67,6 +69,20 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
// Materialize returns the backing key/value buffers currently held by the cache.
|
||||
func (c *KVCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.keys != nil && c.keys.Valid() {
|
||||
out = append(out, c.keys)
|
||||
}
|
||||
if c.values != nil && c.values.Valid() {
|
||||
out = append(out, c.values)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// CanTrim always reports true for KVCache.
func (c *KVCache) CanTrim() bool {
	return true
}
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
@@ -190,6 +206,8 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
// CanTrim always reports true for RotatingKVCache.
func (c *RotatingKVCache) CanTrim() bool {
	return true
}
|
||||
|
||||
func (c *RotatingKVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
220
x/mlxrunner/cache/recurrent.go
vendored
Normal file
220
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
// Break dependency chains so recurrent state does not retain the full
|
||||
// per-token compute graph over time.
|
||||
snap := mlx.Snapshot(v)
|
||||
mlx.Eval(snap)
|
||||
|
||||
old := *dst
|
||||
*dst = snap
|
||||
mlx.Pin(snap)
|
||||
|
||||
// Drop references to the previous cached state root and transient incoming
|
||||
// graph root now that a detached snapshot is retained in cache. Actual
|
||||
// cleanup happens at the runner's normal sweep points.
|
||||
if old != nil && old != snap {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
if v != snap && v != old {
|
||||
mlx.Unpin(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
old := *dst
|
||||
*dst = v
|
||||
mlx.Pin(v)
|
||||
if old != nil && old != v {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
root := v
|
||||
if ensureContiguous {
|
||||
root = mlx.Contiguous(v, false)
|
||||
}
|
||||
detached := mlx.Detach(root)
|
||||
|
||||
old := *dst
|
||||
*dst = detached
|
||||
mlx.Pin(detached)
|
||||
if old != nil && old != detached {
|
||||
mlx.Unpin(old)
|
||||
}
|
||||
|
||||
// Intentionally do not force-release root/v here. In the fast path, the detached
|
||||
// handle aliases the same MLX value and may still be lazily computed. Releasing the
|
||||
// source handles can invalidate the cached state before the next eval/sweep point.
|
||||
}
|
||||
|
||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
snap := mlx.Snapshot(a)
|
||||
mlx.Eval(snap)
|
||||
mlx.Pin(snap)
|
||||
return snap
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
return
|
||||
}
|
||||
|
||||
if needConv {
|
||||
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
if needDelta {
|
||||
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.convState, v)
|
||||
}
|
||||
|
||||
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point. The conv-state input is usually a slice view, so request a
|
||||
// compact contiguous copy to avoid pinning the whole source buffer.
|
||||
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
|
||||
c.setStateDetached(&c.convState, v, true)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.deltaState, v)
|
||||
}
|
||||
|
||||
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
|
||||
// Use only for decode hot paths that accept higher transient memory until the next
|
||||
// sync/sweep point.
|
||||
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
|
||||
c.setStateDetached(&c.deltaState, v, false)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.convState, c.deltaState
|
||||
}
|
||||
|
||||
// Materialize returns the recurrent state roots (conv and delta) held by the cache.
|
||||
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
out = append(out, c.convState)
|
||||
}
|
||||
if c.deltaState != nil && c.deltaState.Valid() {
|
||||
out = append(out, c.deltaState)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// CanTrim always reports false: recurrent state cannot be trimmed.
func (c *RecurrentCache) CanTrim() bool {
	return false
}
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
||||
_ = n
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
convState: snapshotPinned(c.convState),
|
||||
deltaState: snapshotPinned(c.deltaState),
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
mlx.Unpin(c.convState, c.deltaState)
|
||||
c.convState, c.deltaState = nil, nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
// Offset returns the cache's current offset.
func (c *RecurrentCache) Offset() int {
	return c.offset
}
|
||||
// Len returns the cache's current offset; it reports the same value as Offset.
func (c *RecurrentCache) Len() int {
	return c.offset
}
|
||||
@@ -7,4 +7,6 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
275
x/mlxrunner/mlx/gated_delta_metal.go
Normal file
@@ -0,0 +1,275 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled atomic.Bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
dtype := q.DType()
|
||||
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled.Load() {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled.Store(true)
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output.Valid() {
|
||||
if output != nil && output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,6 +113,35 @@ func Where(condition, a, b *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -271,6 +300,24 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -288,7 +335,11 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -378,6 +429,27 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||
func Snapshot(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("SNAPSHOT")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Detach returns a new Array handle that shares the same MLX value but does
|
||||
// not retain Go-side graph input references.
|
||||
func Detach(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("DETACH")
|
||||
C.mlx_array_set(&out.ctx, a.ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -13,11 +13,20 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// prefillChunkSize returns the upper bound on how many prompt tokens
// TextGenerationPipeline feeds to the model in one prompt-processing step.
func prefillChunkSize() int {
	const chunkTokens = 1 << 11 // 2048
	return chunkTokens
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
ctx := request.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
var (
|
||||
sample, logprobs *mlx.Array
|
||||
nextSample, nextLogprobs *mlx.Array
|
||||
@@ -51,24 +60,33 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := min(2<<10, total-processed-1)
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
materializeCaches()
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
@@ -96,7 +114,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
now := time.Now()
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
if err := request.Ctx.Err(); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,8 +138,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- Response{
|
||||
Text: r.Decode(output, &b),
|
||||
Token: int(output),
|
||||
@@ -139,8 +157,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
select {
|
||||
case <-request.Ctx.Done():
|
||||
return request.Ctx.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,40 @@ type LinearLayer interface {
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// Conv1d applies 1D convolution over NLC input.
|
||||
type Conv1d struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
Stride int32
|
||||
Padding int32
|
||||
Dilation int32
|
||||
Groups int32
|
||||
}
|
||||
|
||||
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
||||
if stride <= 0 {
|
||||
stride = 1
|
||||
}
|
||||
if dilation <= 0 {
|
||||
dilation = 1
|
||||
}
|
||||
if groups <= 0 {
|
||||
groups = 1
|
||||
}
|
||||
return &Conv1d{
|
||||
Weight: weight,
|
||||
Bias: bias,
|
||||
Stride: stride,
|
||||
Padding: padding,
|
||||
Dilation: dilation,
|
||||
Groups: groups,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
||||
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
type Linear struct {
|
||||
Weight *mlx.Array
|
||||
|
||||
1457
x/models/qwen3_5/qwen3_5.go
Normal file
1457
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
166
x/models/qwen3_5/qwen3_5_test.go
Normal file
166
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTensorPathLayout(t *testing.T) {
|
||||
dummy := mlx.New("dummy")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantContainer string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
key: "model.embed_tokens.weight",
|
||||
wantContainer: "",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model with inner model",
|
||||
key: "model.language_model.model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "model.",
|
||||
},
|
||||
{
|
||||
name: "nested language model without inner model",
|
||||
key: "model.language_model.embed_tokens.weight",
|
||||
wantContainer: "model.language_model.",
|
||||
wantModel: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
layout := resolveTensorPathLayout(map[string]*mlx.Array{
|
||||
tt.key: dummy,
|
||||
})
|
||||
|
||||
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
|
||||
t.Fatalf(
|
||||
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
|
||||
layout.containerPrefix,
|
||||
layout.modelPrefix,
|
||||
tt.wantContainer,
|
||||
tt.wantModel,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRuntimeDefaults(t *testing.T) {
|
||||
m := &Model{}
|
||||
if m.DisablePromptCache() {
|
||||
t.Fatal("DisablePromptCache() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
Reference in New Issue
Block a user