mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
1 Commits
main
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00f67e807a |
20
x/mlxrunner/cache/cache.go
vendored
20
x/mlxrunner/cache/cache.go
vendored
@@ -4,13 +4,19 @@ package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func kvCacheGrowDebugEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_CACHE_GROW") != ""
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
State() (keys, values *mlx.Array)
|
||||
Materialize() []*mlx.Array
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Offset() int
|
||||
@@ -48,6 +54,9 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
}
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
slog.Info("KVCache grow", "prev", prev, "new_capacity", c.keys.Dim(2), "step", c.step)
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += L
|
||||
@@ -66,6 +75,17 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.keys != nil && c.keys.Valid() {
|
||||
out = append(out, c.keys)
|
||||
}
|
||||
if c.values != nil && c.values.Valid() {
|
||||
out = append(out, c.values)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
|
||||
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
17
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestKVCacheGrowDebugEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "")
|
||||
if kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_CACHE_GROW", "1")
|
||||
if !kvCacheGrowDebugEnabled() {
|
||||
t.Fatal("kvCacheGrowDebugEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
162
x/mlxrunner/cache/recurrent.go
vendored
Normal file
162
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build mlx
|
||||
|
||||
package cache
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
if *dst == v {
|
||||
return
|
||||
}
|
||||
|
||||
// Break dependency chains so recurrent state does not retain the full
|
||||
// per-token compute graph over time.
|
||||
snap := mlx.Snapshot(v)
|
||||
mlx.Eval(snap)
|
||||
|
||||
old := *dst
|
||||
*dst = snap
|
||||
|
||||
// Release previous cached state root, then recursively free the transient
|
||||
// incoming graph root now that a detached snapshot is retained in cache.
|
||||
if old != nil && old != snap {
|
||||
mlx.Release(old)
|
||||
}
|
||||
if v != snap && v != old {
|
||||
mlx.Free(v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
|
||||
if v == nil || !v.Valid() {
|
||||
return
|
||||
}
|
||||
old := *dst
|
||||
*dst = v
|
||||
if old != nil && old != v {
|
||||
mlx.Release(old)
|
||||
}
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
if c.convState == nil || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim {
|
||||
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
|
||||
}
|
||||
|
||||
if c.deltaState == nil || c.deltaState.DType() != dtype ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim {
|
||||
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.convState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.convState, v)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
|
||||
c.ensure(batch, dtype)
|
||||
return c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
|
||||
c.setStateMaterialized(&c.deltaState, v)
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Advance(n int) {
|
||||
c.offset += n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
c.ensure(1, mlx.DTypeFloat32)
|
||||
return c.convState, c.deltaState
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Materialize() []*mlx.Array {
|
||||
out := make([]*mlx.Array, 0, 2)
|
||||
if c.convState != nil && c.convState.Valid() {
|
||||
out = append(out, c.convState)
|
||||
}
|
||||
if c.deltaState != nil && c.deltaState.Valid() {
|
||||
out = append(out, c.deltaState)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
// Recurrent state cannot be reversed cheaply; reset to a clean state when trimming.
|
||||
if n > 0 {
|
||||
if c.convState != nil {
|
||||
c.setStateRaw(&c.convState, mlx.Zeros(c.convState.DType(), c.convState.Dim(0), c.convState.Dim(1), c.convState.Dim(2)))
|
||||
}
|
||||
if c.deltaState != nil {
|
||||
c.setStateRaw(&c.deltaState, mlx.Zeros(c.deltaState.DType(), c.deltaState.Dim(0), c.deltaState.Dim(1), c.deltaState.Dim(2), c.deltaState.Dim(3)))
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
}
|
||||
if c.convState != nil {
|
||||
clone.convState = c.convState.Clone()
|
||||
}
|
||||
if c.deltaState != nil {
|
||||
clone.deltaState = c.deltaState.Clone()
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
@@ -7,4 +7,6 @@ import (
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
|
||||
@@ -272,3 +272,39 @@ func Free(s ...*Array) (n int) {
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Release forcibly frees arrays regardless of reference accounting.
|
||||
// Use only for arrays that are known to be unreachable by any live model state.
|
||||
func Release(s ...*Array) (n int) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
t.desc.inputs = nil
|
||||
t.desc.numRefs = 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
const pinnedNumRefs = 1 << 30
|
||||
|
||||
// Pin keeps arrays alive for the process lifetime by setting a very high
|
||||
// reference count floor. Use for model parameter tensors shared across many
|
||||
// decode steps, where recursive Free traversals must never reclaim them.
|
||||
func Pin(s ...*Array) {
|
||||
seen := make(map[*Array]bool, len(s))
|
||||
for _, t := range s {
|
||||
if t == nil || !t.Valid() || seen[t] {
|
||||
continue
|
||||
}
|
||||
seen[t] = true
|
||||
if t.desc.numRefs < pinnedNumRefs {
|
||||
t.desc.numRefs = pinnedNumRefs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,6 +279,24 @@ func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP", a)
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG", a)
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS", a)
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
||||
mask := New("")
|
||||
sinks := New("")
|
||||
@@ -386,6 +404,52 @@ func Collect(v any) []*Array {
|
||||
return arrays
|
||||
}
|
||||
|
||||
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
|
||||
func Snapshot(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("SNAPSHOT")
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// CollectReachable collects arrays from v and all transitive graph inputs.
|
||||
func CollectReachable(v any) []*Array {
|
||||
roots := Collect(v)
|
||||
if len(roots) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := make(map[*Array]bool, len(roots))
|
||||
out := make([]*Array, 0, len(roots))
|
||||
stack := append([]*Array(nil), roots...)
|
||||
for len(stack) > 0 {
|
||||
a := stack[len(stack)-1]
|
||||
stack = stack[:len(stack)-1]
|
||||
|
||||
if a == nil || !a.Valid() || seen[a] {
|
||||
continue
|
||||
}
|
||||
seen[a] = true
|
||||
out = append(out, a)
|
||||
stack = append(stack, a.desc.inputs...)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Detach returns a new Array handle that shares the same MLX value but does
|
||||
// not retain Go-side graph input references.
|
||||
func Detach(a *Array) *Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return a
|
||||
}
|
||||
out := New("DETACH")
|
||||
C.mlx_array_set(&out.ctx, a.ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
|
||||
@@ -6,12 +6,43 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func prefillChunkSize(lowMemoryDecode bool) int {
|
||||
if v := os.Getenv("OLLAMA_MLX_PREFILL_CHUNK"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
if lowMemoryDecode {
|
||||
// Recurrent/no-prompt-cache path favors lower peak memory over prefill throughput.
|
||||
// Keep this conservative to avoid transient prefill spikes and allocator thrash.
|
||||
return 32
|
||||
}
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func mlxDebugMemoryEnabled() bool {
|
||||
return os.Getenv("OLLAMA_MLX_DEBUG_MEMORY") != ""
|
||||
}
|
||||
|
||||
func finalizeRequestCaches(usePromptCache bool, insertCache func(), freeCaches func(), logMemory func(string, int)) {
|
||||
if usePromptCache {
|
||||
insertCache()
|
||||
logMemory("request_done_cached", -1)
|
||||
return
|
||||
}
|
||||
freeCaches()
|
||||
logMemory("request_done_freed", -1)
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
@@ -29,7 +60,21 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
|
||||
caches, tokens := r.FindNearestCache(inputs)
|
||||
usePromptCache := true
|
||||
if m, ok := r.Model.(interface{ DisablePromptCache() bool }); ok && m.DisablePromptCache() {
|
||||
usePromptCache = false
|
||||
}
|
||||
lowMemoryDecode := !usePromptCache
|
||||
prefillChunk := prefillChunkSize(lowMemoryDecode)
|
||||
|
||||
var caches []cache.Cache
|
||||
var tokens []int32
|
||||
if usePromptCache {
|
||||
caches, tokens = r.FindNearestCache(inputs)
|
||||
} else {
|
||||
tokens = inputs
|
||||
}
|
||||
|
||||
if len(caches) == 0 {
|
||||
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
caches = cacheFactory.NewCaches()
|
||||
@@ -41,23 +86,54 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
freeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = append(state, c.Materialize()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
// Non-prompt-cache requests allocate fresh caches every generation.
|
||||
// Explicitly free cache roots so graph chains are reclaimed promptly.
|
||||
mlx.Free(state...)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
debugMemory := mlxDebugMemoryEnabled()
|
||||
logMemory := func(phase string, token int) {
|
||||
if !debugMemory {
|
||||
return
|
||||
}
|
||||
if token >= 0 {
|
||||
slog.Info("MLX memory", "phase", phase, "token", token, "memory", mlx.Memory{})
|
||||
return
|
||||
}
|
||||
slog.Info("MLX memory", "phase", phase, "memory", mlx.Memory{})
|
||||
}
|
||||
logMemory("prefill_start", -1)
|
||||
|
||||
total, processed := len(tokens), 0
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
n := min(2<<10, total-processed-1)
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
defer mlx.Free(temp)
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
s[2*i], s[2*i+1] = c.State()
|
||||
}
|
||||
return s
|
||||
}()...)
|
||||
materializeCaches()
|
||||
mlx.Free(temp)
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
mlx.ClearCache()
|
||||
}
|
||||
logMemory("prefill_done", -1)
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
@@ -69,7 +145,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
if !lowMemoryDecode {
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
} else {
|
||||
// Materialize cache updates to prevent transform graph growth.
|
||||
materializeCaches()
|
||||
}
|
||||
logMemory("decode_init", -1)
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
@@ -77,12 +159,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
for i := range request.Options.MaxTokens {
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
mlx.AsyncEval(nextSample, nextLogprobs)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
mlx.Eval(sample)
|
||||
logMemory("decode_first_eval", i)
|
||||
final.PromptTokensDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
@@ -94,6 +174,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
mlx.Free(sample, logprobs)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -102,18 +183,43 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
Token: int(output),
|
||||
}
|
||||
|
||||
// For recurrent linear-attention models, avoid async prefetch to reduce
|
||||
// peak memory and clear allocator cache every token.
|
||||
if lowMemoryDecode {
|
||||
mlx.Free(sample, logprobs)
|
||||
if i+1 >= request.Options.MaxTokens {
|
||||
break
|
||||
}
|
||||
mlx.ClearCache()
|
||||
sample, logprobs = step(mlx.FromValues([]int32{output}, 1))
|
||||
// Materialize cache updates to avoid unbounded transform chains.
|
||||
materializeCaches()
|
||||
if i%32 == 0 {
|
||||
logMemory("decode_lowmem_step", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
mlx.AsyncEval(nextSample, nextLogprobs)
|
||||
|
||||
mlx.Free(sample, logprobs)
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
if i%64 == 0 {
|
||||
logMemory("decode_async_step", i)
|
||||
}
|
||||
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Free(sample, logprobs)
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
finalizeRequestCaches(usePromptCache,
|
||||
func() { r.InsertCache(append(inputs, outputs...), caches) },
|
||||
freeCaches,
|
||||
logMemory,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
83
x/mlxrunner/pipeline_helpers_test.go
Normal file
83
x/mlxrunner/pipeline_helpers_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPrefillChunkSize(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "")
|
||||
if got := prefillChunkSize(false); got != 2<<10 {
|
||||
t.Fatalf("prefillChunkSize(false) = %d, want %d", got, 2<<10)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 32 {
|
||||
t.Fatalf("prefillChunkSize(true) = %d, want %d", got, 32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrefillChunkSizeEnvOverride(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_PREFILL_CHUNK", "96")
|
||||
if got := prefillChunkSize(false); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(false) with env = %d, want %d", got, 96)
|
||||
}
|
||||
if got := prefillChunkSize(true); got != 96 {
|
||||
t.Fatalf("prefillChunkSize(true) with env = %d, want %d", got, 96)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMLXDebugMemoryEnabled(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "")
|
||||
if mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = true, want false")
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MLX_DEBUG_MEMORY", "1")
|
||||
if !mlxDebugMemoryEnabled() {
|
||||
t.Fatal("mlxDebugMemoryEnabled() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesPromptCachePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
true,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 1 {
|
||||
t.Fatalf("insert calls = %d, want 1", insertCalls)
|
||||
}
|
||||
if freeCalls != 0 {
|
||||
t.Fatalf("free calls = %d, want 0", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_cached" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFinalizeRequestCachesUsesFreePath(t *testing.T) {
|
||||
insertCalls := 0
|
||||
freeCalls := 0
|
||||
logPhase := ""
|
||||
|
||||
finalizeRequestCaches(
|
||||
false,
|
||||
func() { insertCalls++ },
|
||||
func() { freeCalls++ },
|
||||
func(phase string, _ int) { logPhase = phase },
|
||||
)
|
||||
|
||||
if insertCalls != 0 {
|
||||
t.Fatalf("insert calls = %d, want 0", insertCalls)
|
||||
}
|
||||
if freeCalls != 1 {
|
||||
t.Fatalf("free calls = %d, want 1", freeCalls)
|
||||
}
|
||||
if logPhase != "request_done_freed" {
|
||||
t.Fatalf("log phase = %q, want %q", logPhase, "request_done_freed")
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,38 @@ type Runner struct {
|
||||
CacheEntries map[int32]*CacheEntry
|
||||
}
|
||||
|
||||
func releaseTensorMap(tensors map[string]*mlx.Array, keep map[*mlx.Array]struct{}) (count int, bytes int) {
|
||||
if len(tensors) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
seen := make(map[*mlx.Array]bool, len(tensors))
|
||||
toRelease := make([]*mlx.Array, 0, len(tensors))
|
||||
for name, arr := range tensors {
|
||||
if arr == nil || !arr.Valid() {
|
||||
delete(tensors, name)
|
||||
continue
|
||||
}
|
||||
if keep != nil {
|
||||
if _, ok := keep[arr]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(tensors, name)
|
||||
if seen[arr] {
|
||||
continue
|
||||
}
|
||||
seen[arr] = true
|
||||
toRelease = append(toRelease, arr)
|
||||
}
|
||||
|
||||
if len(toRelease) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
return len(toRelease), mlx.Release(toRelease...)
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
root, err := model.Open(modelName)
|
||||
if err != nil {
|
||||
@@ -85,9 +117,29 @@ func (r *Runner) Load(modelName string) error {
|
||||
// Assign weights to model (model-specific logic)
|
||||
loadWeights := base.Weights(m)
|
||||
if err := loadWeights(tensors); err != nil {
|
||||
if count, bytes := releaseTensorMap(tensors, nil); count > 0 {
|
||||
slog.Info("Released tensors after load failure", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
mlx.ClearCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Pin only model-owned tensor roots. Pinning the full transitive graph can
|
||||
// retain large load-time intermediates and inflate steady-state memory.
|
||||
roots := mlx.Collect(m)
|
||||
mlx.Pin(roots...)
|
||||
|
||||
keep := make(map[*mlx.Array]struct{})
|
||||
for _, arr := range roots {
|
||||
if arr != nil && arr.Valid() {
|
||||
keep[arr] = struct{}{}
|
||||
}
|
||||
}
|
||||
if count, bytes := releaseTensorMap(tensors, keep); count > 0 {
|
||||
slog.Info("Released unused model tensors", "count", count, "bytes", mlx.PrettyBytes(bytes))
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
return nil
|
||||
|
||||
1254
x/models/qwen3_5/qwen3_5.go
Normal file
1254
x/models/qwen3_5/qwen3_5.go
Normal file
File diff suppressed because it is too large
Load Diff
120
x/models/qwen3_5/qwen3_5_test.go
Normal file
120
x/models/qwen3_5/qwen3_5_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build mlx
|
||||
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model_type": "Qwen3_5MoeForConditionalGeneration",
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 8,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"num_experts": 16,
|
||||
"num_experts_per_tok": 4,
|
||||
"moe_intermediate_size": 2048,
|
||||
"shared_expert_intermediate_size": 4096,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 500000,
|
||||
"partial_rotary_factor": 0.5
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.RopeTheta != 500000 {
|
||||
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.RopeDim != 64 {
|
||||
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||
}
|
||||
if cfg.FullAttentionInterval != 4 {
|
||||
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
|
||||
}
|
||||
if !cfg.NormTopKProb {
|
||||
t.Fatalf("norm_topk_prob should default to true for MoE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
}
|
||||
if layerIsLinear(cfg, 2) {
|
||||
t.Fatalf("layer 2 should be full attention")
|
||||
}
|
||||
|
||||
if layerUsesMoE(cfg, 1) {
|
||||
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
|
||||
}
|
||||
if !layerUsesMoE(cfg, 3) {
|
||||
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRuntimeToggles(t *testing.T) {
|
||||
m := &Model{}
|
||||
if !m.DisablePromptCache() {
|
||||
t.Fatal("DisablePromptCache() = false, want true")
|
||||
}
|
||||
if m.EnableCompile() {
|
||||
t.Fatal("EnableCompile() = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
{IsLinear: true},
|
||||
},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if len(caches) != len(m.Layers) {
|
||||
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
|
||||
}
|
||||
|
||||
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
|
||||
}
|
||||
if _, ok := caches[1].(*cache.KVCache); !ok {
|
||||
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
|
||||
}
|
||||
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
|
||||
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
|
||||
}
|
||||
}
|
||||
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
16
x/models/qwen3_5_moe/qwen3_5_moe.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
|
||||
package qwen3_5_moe
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
func init() {
|
||||
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
|
||||
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
|
||||
}
|
||||
Reference in New Issue
Block a user