Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
330b19b73f server: chunk quantization writes to reduce create memory usage 2026-02-28 23:21:37 -08:00
25 changed files with 181 additions and 3311 deletions

View File

@@ -1453,12 +1453,10 @@ type ImageData struct {
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Think *api.ThinkValue
ExplicitOptions map[string]struct{}
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
Shift bool

View File

@@ -21,33 +21,76 @@ type quantizer struct {
progressFn func(n uint64)
}
const quantizationChunkElements uint64 = 4 * 1024 * 1024
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
quantize := q.from.Kind != q.to.Kind
sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))
if !quantize {
n, err := io.Copy(w, sr)
q.progressFn(q.from.Size())
if q.progressFn != nil {
q.progressFn(q.from.Size())
}
return n, err
}
data, err := io.ReadAll(sr)
if err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
if len(q.from.Shape) == 0 || q.from.Shape[0] == 0 {
return 0, fmt.Errorf("tensor %s has invalid shape %v", q.from.Name, q.from.Shape)
}
if uint64(len(data)) < q.from.Size() {
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
fromType := fsggml.TensorType(q.from.Kind)
toType := fsggml.TensorType(q.to.Kind)
nPerRow := q.from.Shape[0]
totalElements := q.from.Elements()
if totalElements%nPerRow != 0 {
return 0, fmt.Errorf("tensor %s has non-row-aligned shape %v", q.from.Name, q.from.Shape)
}
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
inRowSize := fromType.RowSize(nPerRow)
if inRowSize == 0 {
return 0, fmt.Errorf("tensor %s has unsupported source type %v", q.from.Name, fromType)
}
data = ggml.Quantize(newType, f32s, q.from.Shape)
n, err := w.Write(data)
q.progressFn(q.from.Size())
return int64(n), err
totalRows := totalElements / nPerRow
rowsPerChunk := max(quantizationChunkElements/nPerRow, uint64(1))
chunkBuf := make([]byte, inRowSize*rowsPerChunk)
var written int64
for row := uint64(0); row < totalRows; {
chunkRows := min(rowsPerChunk, totalRows-row)
chunkBytes := inRowSize * chunkRows
data := chunkBuf[:chunkBytes]
if _, err := io.ReadFull(sr, data); err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return written, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
}
var f32s []float32
chunkElements := chunkRows * nPerRow
if fromType == fsggml.TensorTypeF32 {
f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), chunkElements)
} else {
f32s = ggml.ConvertToF32(data, q.from.Kind, chunkElements)
}
quantized := ggml.Quantize(toType, f32s, []uint64{nPerRow, chunkRows})
n, err := w.Write(quantized)
written += int64(n)
if err != nil {
return written, err
}
if n != len(quantized) {
return written, io.ErrShortWrite
}
if q.progressFn != nil {
q.progressFn(chunkBytes)
}
row += chunkRows
}
return written, nil
}
type quantizeState struct {

View File

@@ -130,35 +130,6 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
return opts, nil
}
func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
keys := []string{
"temperature",
"top_p",
"min_p",
"top_k",
"repeat_last_n",
"repeat_penalty",
"presence_penalty",
"frequency_penalty",
}
explicit := make(map[string]struct{}, len(keys))
for _, key := range keys {
if optionSpecified(modelOpts, requestOpts, key) {
explicit[key] = struct{}{}
}
}
return explicit
}
func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
if _, ok := requestOpts[key]; ok {
return true
}
_, ok := modelOpts[key]
return ok
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
@@ -568,16 +539,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Think: req.Think,
ExplicitOptions: explicitOptions(m.Options, req.Options),
Shift: req.Shift == nil || *req.Shift,
Truncate: req.Truncate == nil || *req.Truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: req.Truncate == nil || *req.Truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -2329,16 +2298,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Think: req.Think,
ExplicitOptions: explicitOptions(m.Options, req.Options),
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,

View File

@@ -288,18 +288,6 @@ func normalizeQuantType(quantize string) string {
}
}
func isStackedExpertWeight(name string) bool {
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
// or "...proj" (pre-stacked packed tensor).
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
return false
}
return strings.Contains(name, ".mlp.switch_mlp.") ||
strings.Contains(name, ".mlp.experts.") ||
strings.Contains(name, ".mlp.shared_experts.")
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
@@ -308,25 +296,18 @@ func isStackedExpertWeight(name string) bool {
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - Norms, embeddings, biases, routing gates: no quantization
func GetTensorQuantization(name string, shape []int32, quantize string) string {
stackedExpert := isStackedExpertWeight(name)
// Use basic name-based check first
if !stackedExpert && !ShouldQuantize(name, "") {
if !ShouldQuantize(name, "") {
return ""
}
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
// e.g. qwen switch_mlp / experts combined tensors.
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
var elems int64 = 1
for _, d := range shape {
elems *= int64(d)
}
if elems < 1024 {
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
return ""
}

View File

@@ -557,10 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
// 3D+ tensors should not be quantized
{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
// Embeddings should not be quantized regardless of shape
{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -623,44 +619,6 @@ func TestExpertGroupPrefix(t *testing.T) {
}
}
func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
gateUp := GetTensorQuantization(
"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
[]int32{64, 22016, 4096},
"int4",
)
if gateUp != "int4" {
t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
}
down := GetTensorQuantization(
"model.layers.1.mlp.experts.down_proj.weight",
[]int32{64, 4096, 14336},
"int4",
)
if down != "int8" {
t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
}
combinedGateUp := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.gate_up_proj",
[]int32{256, 1024, 2048},
"int8",
)
if combinedGateUp != "int8" {
t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
}
combinedDown := GetTensorQuantization(
"model.language_model.layers.0.mlp.experts.down_proj",
[]int32{256, 2048, 512},
"int4",
)
if combinedDown != "int8" {
t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
}
}
func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
dir := t.TempDir()

View File

@@ -30,64 +30,21 @@ type cacheSession struct {
remaining []int32
}
func (c *kvCache) free() {
for i, kv := range c.caches {
if kv == nil {
continue
}
kv.Free()
c.caches[i] = nil
}
c.caches = nil
c.tokens = nil
}
func (c *kvCache) cachesCanTrim() bool {
for _, kv := range c.caches {
if kv == nil {
continue
}
if !kv.CanTrim() {
return false
}
}
return true
}
func (c *kvCache) trimToPrefix(prefix int) {
for _, kv := range c.caches {
if kv == nil || !kv.CanTrim() {
continue
}
if trim := kv.Offset() - prefix; trim > 0 {
kv.Trim(trim)
}
}
if prefix < len(c.tokens) {
c.tokens = c.tokens[:prefix]
}
}
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
ensureCaches := func() {
if len(c.caches) != 0 {
return
}
if len(c.caches) == 0 {
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
return
}
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
} else {
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
}
}
}
ensureCaches()
remaining := c.findRemaining(inputs)
ensureCaches()
return &cacheSession{
cache: c,
@@ -99,34 +56,18 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
if len(s.caches) == 0 {
return
}
offset := -1
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, kv := range s.caches {
if kv == nil {
continue
if offset := s.caches[0].Offset(); offset > 0 {
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, c := range s.caches {
k, v := c.State()
arrays = append(arrays, k, v)
}
if off := kv.Offset(); offset < 0 || off < offset {
offset = off
}
arrays = append(arrays, kv.Materialize()...)
}
if offset <= 0 {
return
}
mlx.AsyncEval(arrays...)
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data.
mlx.AsyncEval(arrays...)
stored := append(s.inputs, s.outputs...)
if offset > len(stored) {
offset = len(stored)
s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
}
s.cache.tokens = stored[:offset]
}
// findRemaining finds the longest common prefix between tokens and the cached
@@ -144,13 +85,11 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
}
if prefix < len(c.tokens) {
if c.cachesCanTrim() {
c.trimToPrefix(prefix)
} else {
c.free()
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
return tokens
trim := len(c.tokens) - prefix
for _, kv := range c.caches {
kv.Trim(trim)
}
c.tokens = c.tokens[:prefix]
}
if prefix == 0 {
@@ -165,21 +104,10 @@ func (c *kvCache) log() {
if len(c.caches) == 0 {
return
}
offset := -1
var totalBytes int
for _, kv := range c.caches {
if kv == nil {
continue
}
if off := kv.Offset(); offset < 0 || off < offset {
offset = off
}
for _, a := range kv.Materialize() {
totalBytes += a.NumBytes()
}
k, v := kv.State()
totalBytes += k.NumBytes() + v.NumBytes()
}
if offset < 0 {
return
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
}

View File

@@ -10,8 +10,6 @@ import (
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Materialize() []*mlx.Array
CanTrim() bool
Trim(int) int
Clone() Cache
Free()
@@ -69,20 +67,6 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
// Materialize returns the backing key/value buffers currently held by the cache.
func (c *KVCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.keys != nil && c.keys.Valid() {
out = append(out, c.keys)
}
if c.values != nil && c.values.Valid() {
out = append(out, c.values)
}
return out
}
func (c *KVCache) CanTrim() bool { return true }
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
@@ -206,8 +190,6 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
return c.keys, c.values
}
func (c *RotatingKVCache) CanTrim() bool { return true }
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n

View File

@@ -1,220 +0,0 @@
//go:build mlx
package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
}
func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
// Break dependency chains so recurrent state does not retain the full
// per-token compute graph over time.
snap := mlx.Snapshot(v)
mlx.Eval(snap)
old := *dst
*dst = snap
mlx.Pin(snap)
// Drop references to the previous cached state root and transient incoming
// graph root now that a detached snapshot is retained in cache. Actual
// cleanup happens at the runner's normal sweep points.
if old != nil && old != snap {
mlx.Unpin(old)
}
if v != snap && v != old {
mlx.Unpin(v)
}
}
func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
old := *dst
*dst = v
mlx.Pin(v)
if old != nil && old != v {
mlx.Unpin(old)
}
}
func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
if v == nil || !v.Valid() {
return
}
if *dst == v {
return
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := mlx.Detach(root)
old := *dst
*dst = detached
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Unpin(old)
}
// Intentionally do not force-release root/v here. In the fast path, the detached
// handle aliases the same MLX value and may still be lazily computed. Releasing the
// source handles can invalidate the cached state before the next eval/sweep point.
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Snapshot(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
if needConv {
c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
c.setStateMaterialized(&c.convState, v)
}
// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point. The conv-state input is usually a slice view, so request a
// compact contiguous copy to avoid pinning the whole source buffer.
func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
c.setStateDetached(&c.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
c.setStateMaterialized(&c.deltaState, v)
}
// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
// Use only for decode hot paths that accept higher transient memory until the next
// sync/sweep point.
func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
c.setStateDetached(&c.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
return c.convState, c.deltaState
}
// Materialize returns the recurrent state roots (conv and delta) held by the cache.
func (c *RecurrentCache) Materialize() []*mlx.Array {
out := make([]*mlx.Array, 0, 2)
if c.convState != nil && c.convState.Valid() {
out = append(out, c.convState)
}
if c.deltaState != nil && c.deltaState.Valid() {
out = append(out, c.deltaState)
}
return out
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
convState: snapshotPinned(c.convState),
deltaState: snapshotPinned(c.deltaState),
}
return clone
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }

View File

@@ -182,20 +182,15 @@ func (c *Client) waitUntilRunning() error {
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options *completionOpts `json:"options,omitempty"`
}
type completionOpts struct {
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
MinP *float32 `json:"min_p,omitempty"`
TopK *int `json:"top_k,omitempty"`
RepeatLastN *int `json:"repeat_last_n,omitempty"`
RepeatPenalty *float32 `json:"repeat_penalty,omitempty"`
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
@@ -233,27 +228,16 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var think *bool
if req.Think != nil {
enabled := req.Think.Bool()
think = &enabled
}
creq := completionRequest{
Prompt: req.Prompt,
Think: think,
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")),
TopP: float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")),
MinP: float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")),
TopK: intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")),
RepeatLastN: intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")),
RepeatPenalty: float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
PresencePenalty: float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
NumPredict: req.Options.NumPredict,
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
NumPredict: req.Options.NumPredict,
}
}
@@ -312,25 +296,6 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
return scanner.Err()
}
func hasExplicitOption(explicit map[string]struct{}, key string) bool {
_, ok := explicit[key]
return ok
}
func float32Ptr(v float32, ok bool) *float32 {
if !ok {
return nil
}
return &v
}
func intPtr(v int, ok bool) *int {
if !ok {
return nil
}
return &v
}
func (c *Client) ContextLength() int {
return int(c.contextLength.Load())
}

View File

@@ -1,167 +0,0 @@
package mlxrunner
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
func TestCompletionForwardsThink(t *testing.T) {
boolPtr := func(v bool) *bool { return &v }
testCases := []struct {
name string
think *api.ThinkValue
want *bool
}{
{name: "unset", think: nil, want: nil},
{name: "enabled", think: &api.ThinkValue{Value: true}, want: boolPtr(true)},
{name: "disabled", think: &api.ThinkValue{Value: false}, want: boolPtr(false)},
{name: "level maps to enabled", think: &api.ThinkValue{Value: "high"}, want: boolPtr(true)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/completion" {
t.Fatalf("request path = %q, want %q", r.URL.Path, "/completion")
}
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Think: tc.think,
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Prompt != "hello" {
t.Fatalf("prompt = %q, want %q", got.Prompt, "hello")
}
switch {
case tc.want == nil && got.Think != nil:
t.Fatalf("think = %v, want nil", *got.Think)
case tc.want != nil && got.Think == nil:
t.Fatalf("think = nil, want %v", *tc.want)
case tc.want != nil && got.Think != nil && *tc.want != *got.Think:
t.Fatalf("think = %v, want %v", *got.Think, *tc.want)
}
})
}
}
func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
var got completionRequest
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
return nil, err
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
Request: r,
}, nil
})
c := &Client{
port: 11434,
client: &http.Client{
Transport: rt,
},
}
opts := &api.Options{
Temperature: 1.0,
TopP: 0.95,
MinP: 0.1,
TopK: 20,
RepeatLastN: 128,
RepeatPenalty: 1.2,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
NumPredict: 64,
}
err := c.Completion(context.Background(), llm.CompletionRequest{
Prompt: "hello",
Options: opts,
ExplicitOptions: map[string]struct{}{
"temperature": {},
"top_k": {},
"repeat_penalty": {},
"presence_penalty": {},
},
}, func(llm.CompletionResponse) {})
if err != nil {
t.Fatalf("completion request failed: %v", err)
}
if got.Options == nil {
t.Fatal("options = nil, want serialized options")
}
if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
}
if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
}
if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
}
if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
}
if got.Options.TopP != nil {
t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
}
if got.Options.MinP != nil {
t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
}
if got.Options.RepeatLastN != nil {
t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
}
if got.Options.FrequencyPenalty != nil {
t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
}
if got.Options.NumPredict != opts.NumPredict {
t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

View File

@@ -7,6 +7,4 @@ import (
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
)

View File

@@ -1,275 +0,0 @@
//go:build mlx
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"sync/atomic"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled atomic.Bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<InT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled.Store(true)
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
dtype := q.DType()
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled.Load() {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled.Store(true)
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}

View File

@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output != nil && output.Valid() {
if output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}

View File

@@ -93,12 +93,6 @@ func (t *Array) Divide(other *Array) *Array {
return out
}
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM")
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
@@ -129,30 +123,12 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out
}
func (t *Array) GreaterEqual(other *Array) *Array {
out := New("GREATER_EQUAL")
C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LogicalOr(other *Array) *Array {
out := New("LOGICAL_OR")
C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View File

@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
return out
}
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
mask := New("")
sinks := New("")
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
@@ -429,27 +378,6 @@ func Collect(v any) []*Array {
return arrays
}
// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
func Snapshot(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("SNAPSHOT")
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Detach returns a new Array handle that shares the same MLX value but does
// not retain Go-side graph input references.
func Detach(a *Array) *Array {
if a == nil || !a.Valid() {
return a
}
out := New("DETACH")
C.mlx_array_set(&out.ctx, a.ctx)
return out
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return

View File

@@ -16,20 +16,11 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
ctx := request.Ctx
if ctx == nil {
ctx = context.Background()
}
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
@@ -82,46 +73,36 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
defer session.close()
caches := session.caches
tokens := session.remaining
history := append([]int32(nil), session.inputs...)
prefillChunk := prefillChunkSize()
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
if c == nil {
continue
}
state = append(state, c.Materialize()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
now := time.Now()
total, processed := len(tokens), 0
for total-processed > 1 {
if err := ctx.Err(); err != nil {
if err := request.Ctx.Err(); err != nil {
return err
}
n := min(prefillChunk, total-processed-1)
n := min(2<<10, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
materializeCaches()
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) {
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sample(logprobs, history)
sample := request.Sample(logprobs)
mlx.Pin(sample, logprobs)
mlx.Sweep()
@@ -130,16 +111,18 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return sample, logprobs
}
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history)
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
if err := ctx.Err(); err != nil {
if err := request.Ctx.Err(); err != nil {
return err
}
nextSample, nextLogprobs = step(sample)
if i == 0 {
mlx.Eval(sample)
final.PromptEvalDuration = time.Since(now)
@@ -148,7 +131,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
output := int32(sample.Int())
session.outputs = append(session.outputs, output)
history = append(history, output)
if r.Tokenizer.IsEOS(output) {
final.DoneReason = 0
@@ -164,8 +146,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}:
}
nextSample, nextLogprobs = step(sample, history)
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
@@ -178,8 +158,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
final.EvalDuration = time.Since(now)
final.PeakMemory = uint64(mlx.PeakMemory())
select {
case <-ctx.Done():
return ctx.Err()
case <-request.Ctx.Done():
return request.Ctx.Err()
case request.Responses <- final:
return nil
}

View File

@@ -32,17 +32,12 @@ type Request struct {
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Think *bool `json:"think,omitempty"`
Options struct {
Temperature *float32 `json:"temperature"`
TopP *float32 `json:"top_p"`
MinP *float32 `json:"min_p"`
TopK *int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`

View File

@@ -9,204 +9,69 @@ import (
)
type Sampler interface {
Sample(*mlx.Array, []int32) *mlx.Array
Sample(*mlx.Array) *mlx.Array
}
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler {
var samplers []Sampler
if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) {
samplers = append(samplers, Penalty{
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
})
func New(temp, top_p, min_p float32, top_k int) Sampler {
if temp == 0 {
return greedy{}
}
if temp == 0 {
samplers = append(samplers, greedy{})
} else {
samplers = append(samplers, Distribution{
Temperature: temp,
TopK: top_k,
TopP: top_p,
MinP: min_p,
})
var samplers []Sampler
if top_p > 0 && top_p < 1 {
samplers = append(samplers, TopP(top_p))
}
if min_p != 0 {
samplers = append(samplers, MinP(min_p))
}
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers)
}
type greedy struct{}
func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
type chain []Sampler
func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array {
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c {
logits = sampler.Sample(logits, history)
logits = sampler.Sample(logits)
}
return logits
}
type Distribution struct {
Temperature float32
TopK int
TopP float32
MinP float32
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return mlx.DivScalar(logits, float32(t)).Categorical(-1)
}
func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
filtered, indices := d.filter(logits)
sample := filtered.Categorical(-1)
if indices == nil {
return sample
}
type TopP float32
positions := sample.ExpandDims(1)
return indices.TakeAlongAxis(positions, -1).Squeeze(1)
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) {
candidates := logits
var candidateIndices *mlx.Array
type MinP float32
if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) {
partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1)
switch logits.NumDims() {
case 1:
candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
default:
candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
}
candidates = logits.TakeAlongAxis(candidateIndices, -1)
}
if d.Temperature != 1 {
candidates = mlx.DivScalar(candidates, d.Temperature)
}
if !d.needsProbabilityFilters() {
return candidates, candidateIndices
}
order := candidates.Negative().ArgsortAxis(-1)
sortedLogits := candidates.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
remove := d.topPRemovalMask(sortedProbs)
if d.MinP > 0 {
minPRemove := d.minPRemovalMask(sortedProbs)
if remove == nil {
remove = minPRemove
} else {
remove = remove.LogicalOr(minPRemove)
}
}
if remove == nil {
return candidates, candidateIndices
}
negInf := mlx.FromValue(float32(math.Inf(-1)))
filtered := mlx.Where(remove, negInf, sortedLogits)
return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
func (d Distribution) needsProbabilityFilters() bool {
return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
}
func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
if d.TopP <= 0 || d.TopP >= 1 {
return nil
}
threshold := mlx.NewScalarArray(d.TopP)
prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
return prevCum.GreaterEqual(threshold)
}
func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
if d.MinP <= 0 {
return nil
}
var maxProb *mlx.Array
switch sortedProbs.NumDims() {
case 1:
maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
default:
maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
}
threshold := mlx.MulScalar(maxProb, d.MinP)
return sortedProbs.Less(threshold)
}
type Penalty struct {
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
}
func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
if len(history) == 0 {
return logprobs
}
window := p.RepeatLastN
if window <= 0 || window > len(history) {
window = len(history)
}
counts := make(map[int32]int, window)
order := make([]int32, 0, window)
for _, token := range history[len(history)-window:] {
if token < 0 {
continue
}
if counts[token] == 0 {
order = append(order, token)
}
counts[token]++
}
if len(order) == 0 {
return logprobs
}
indexShape := []int32{int32(len(order))}
valueShape := []int{len(order)}
if logprobs.NumDims() > 1 {
indexShape = []int32{1, int32(len(order))}
valueShape = []int{1, len(order)}
}
indices := mlx.NewArrayInt32(order, indexShape)
selected := logprobs.TakeAlongAxis(indices, -1)
mlx.Eval(selected)
values := selected.Floats()
for i, token := range order {
v := values[i]
if p.RepeatPenalty != 1 {
if v < 0 {
v *= p.RepeatPenalty
} else {
v /= p.RepeatPenalty
}
}
if p.PresencePenalty != 0 {
v -= p.PresencePenalty
}
if p.FrequencyPenalty != 0 {
v -= p.FrequencyPenalty * float32(counts[token])
}
values[i] = v
}
return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
type TopK int
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}

View File

@@ -1,104 +0,0 @@
//go:build mlx
package sample
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestPenaltySample(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, -2.0, 3.0, 4.0,
}, 1, 4)
got := Penalty{
RepeatLastN: 3,
RepeatPenalty: 2.0,
PresencePenalty: 1.5,
FrequencyPenalty: 0.25,
}.Sample(logprobs, []int32{2, 1, 2})
mlx.Eval(got)
want := []float32{1.0, -5.75, -0.5, 4.0}
values := got.Floats()
if len(values) != len(want) {
t.Fatalf("len(values) = %d, want %d", len(values), len(want))
}
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logprobs := mlx.FromValues([]float32{
1.0, 2.0, 3.0,
}, 1, 3)
got := Penalty{
RepeatLastN: 1,
PresencePenalty: 1.0,
}.Sample(logprobs, []int32{0, 1})
mlx.Eval(got)
want := []float32{1.0, 1.0, 3.0}
values := got.Floats()
for i := range want {
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
}
}
}
func TestDistributionFilterTopP(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
logits := mlx.FromValues([]float32{
10.0, 9.0, 1.0, 0.0,
}, 1, 4)
filtered, indices := Distribution{
Temperature: 1.0,
TopK: 2,
TopP: 0.55,
}.filter(logits)
got := materializeFilteredLogits(filtered, indices, 4)
mlx.Eval(got)
values := got.Floats()
if values[0] != 10.0 {
t.Fatalf("values[0] = %v, want 10", values[0])
}
for i := 1; i < len(values); i++ {
if !math.IsInf(float64(values[i]), -1) {
t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
}
}
}
func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
if indices == nil {
return filtered
}
base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
return base.PutAlongAxis(indices, filtered, -1)
}

View File

@@ -16,89 +16,12 @@ import (
"strconv"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/models/qwen3_5"
)
type samplingConfig struct {
temperature float32
topP float32
minP float32
topK int
repeatLastN int
repeatPenalty float32
presencePenalty float32
frequencyPenalty float32
}
func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
if _, ok := m.(*qwen3_5.Model); ok {
cfg := samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
}
if think != nil && !*think {
cfg.temperature = 0.7
cfg.topP = 0.8
}
return cfg
}
opts := api.DefaultOptions()
return samplingConfig{
temperature: opts.Temperature,
topP: opts.TopP,
minP: opts.MinP,
topK: opts.TopK,
repeatLastN: opts.RepeatLastN,
repeatPenalty: opts.RepeatPenalty,
presencePenalty: opts.PresencePenalty,
frequencyPenalty: opts.FrequencyPenalty,
}
}
func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
cfg := defaultSamplingConfig(m, req.Think)
if req.Options.Temperature != nil {
cfg.temperature = *req.Options.Temperature
}
if req.Options.TopP != nil {
cfg.topP = *req.Options.TopP
}
if req.Options.MinP != nil {
cfg.minP = *req.Options.MinP
}
if req.Options.TopK != nil {
cfg.topK = *req.Options.TopK
}
if req.Options.RepeatLastN != nil {
cfg.repeatLastN = *req.Options.RepeatLastN
}
if req.Options.RepeatPenalty != nil {
cfg.repeatPenalty = *req.Options.RepeatPenalty
}
if req.Options.PresencePenalty != nil {
cfg.presencePenalty = *req.Options.PresencePenalty
}
if req.Options.FrequencyPenalty != nil {
cfg.frequencyPenalty = *req.Options.FrequencyPenalty
}
return cfg
}
func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
@@ -167,18 +90,12 @@ func Execute(args []string) error {
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
sampling := resolveSamplingConfig(runner.Model, request)
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
sampling.temperature,
sampling.topP,
sampling.minP,
sampling.topK,
sampling.repeatLastN,
sampling.repeatPenalty,
sampling.presencePenalty,
sampling.frequencyPenalty,
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
)
var cancel context.CancelFunc

View File

@@ -1,172 +0,0 @@
//go:build mlx
package mlxrunner
import (
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
"github.com/ollama/ollama/x/tokenizer"
)
type stubModel struct{}
func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
func (stubModel) NumLayers() int { return 0 }
func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
func TestResolveSamplingConfigDefaults(t *testing.T) {
trueValue := true
falseValue := false
tests := []struct {
name string
model base.Model
req Request
want samplingConfig
}{
{
name: "generic model uses api defaults",
model: stubModel{},
req: Request{},
want: samplingConfig{
temperature: 0.8,
topP: 0.9,
minP: 0.0,
topK: 40,
repeatLastN: 64,
repeatPenalty: 1.1,
presencePenalty: 0.0,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 defaults to thinking profile when think unset",
model: &qwen3_5.Model{},
req: Request{},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking disabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
want: samplingConfig{
temperature: 0.7,
topP: 0.8,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
{
name: "qwen3.5 thinking enabled defaults",
model: &qwen3_5.Model{},
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
want: samplingConfig{
temperature: 1.0,
topP: 0.95,
minP: 0.0,
topK: 20,
repeatLastN: 64,
repeatPenalty: 1.0,
presencePenalty: 1.5,
frequencyPenalty: 0.0,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
}
})
}
}
func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
trueValue := true
temperature := float32(0.4)
topP := float32(0.6)
minP := float32(0.05)
topK := 12
repeatLastN := 32
repeatPenalty := float32(1.1)
presencePenalty := float32(0.7)
frequencyPenalty := float32(0.2)
got := resolveSamplingConfig(stubModel{}, Request{
TextCompletionsRequest: TextCompletionsRequest{
Think: &trueValue,
Options: struct {
Temperature *float32 `json:"temperature"`
TopP *float32 `json:"top_p"`
MinP *float32 `json:"min_p"`
TopK *int `json:"top_k"`
RepeatLastN *int `json:"repeat_last_n"`
RepeatPenalty *float32 `json:"repeat_penalty"`
PresencePenalty *float32 `json:"presence_penalty"`
FrequencyPenalty *float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
NumPredict int `json:"num_predict"`
}{
Temperature: &temperature,
TopP: &topP,
MinP: &minP,
TopK: &topK,
RepeatLastN: &repeatLastN,
RepeatPenalty: &repeatPenalty,
PresencePenalty: &presencePenalty,
FrequencyPenalty: &frequencyPenalty,
},
},
})
want := samplingConfig{
temperature: temperature,
topP: topP,
minP: minP,
topK: topK,
repeatLastN: repeatLastN,
repeatPenalty: repeatPenalty,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
}
if got != want {
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
}
}
func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
want := api.DefaultOptions()
got := defaultSamplingConfig(stubModel{}, nil)
if got.temperature != want.Temperature ||
got.topP != want.TopP ||
got.minP != want.MinP ||
got.topK != want.TopK ||
got.repeatLastN != want.RepeatLastN ||
got.repeatPenalty != want.RepeatPenalty ||
got.presencePenalty != want.PresencePenalty ||
got.frequencyPenalty != want.FrequencyPenalty {
t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
}
}

View File

@@ -15,40 +15,6 @@ type LinearLayer interface {
OutputDim() int32
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,166 +0,0 @@
//go:build mlx
package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestModelRuntimeDefaults(t *testing.T) {
m := &Model{}
if m.DisablePromptCache() {
t.Fatal("DisablePromptCache() = true, want false")
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}

View File

@@ -1,16 +0,0 @@
//go:build mlx
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}