mirror of https://github.com/ollama/ollama.git
synced 2026-03-01 21:46:45 -05:00

Compare commits: pdevine/sa ... pdevine/sa (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 857cffd22a | |

api/types.go (14)
@@ -15,7 +15,6 @@ import (
 	"github.com/google/uuid"
 
 	"github.com/ollama/ollama/envconfig"
-	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/internal/orderedmap"
 	"github.com/ollama/ollama/types/model"
 )
@@ -570,7 +569,6 @@ type DebugInfo struct {
 
 type Metrics struct {
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
-	PeakMemory         uint64        `json:"peak_memory,omitempty"`
 	LoadDuration       time.Duration `json:"load_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
@@ -936,10 +934,6 @@ func (m *Metrics) Summary() {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
 	}
 
-	if m.PeakMemory > 0 {
-		fmt.Fprintf(os.Stderr, "peak memory: %s\n", formatPeakMemory(m.PeakMemory))
-	}
-
 	if m.LoadDuration > 0 {
 		fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
 	}
@@ -963,14 +957,6 @@ func (m *Metrics) Summary() {
 	}
 }
 
-func formatPeakMemory(b uint64) string {
-	if b >= format.GibiByte {
-		return fmt.Sprintf("%.3f GiB", float64(b)/float64(format.GibiByte))
-	}
-
-	return format.HumanBytes2(b)
-}
-
 func (opts *Options) FromMap(m map[string]any) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
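For reference, a minimal standalone sketch of the formatting the removed helper performed (assuming format.GibiByte is 1 << 30, as in ollama's format package; the sample value is illustrative):

package main

import "fmt"

const gibiByte = 1 << 30 // assumed equal to format.GibiByte

func main() {
	b := uint64(3 * gibiByte)                              // 3221225472 bytes
	fmt.Printf("%.3f GiB\n", float64(b)/float64(gibiByte)) // prints "3.000 GiB"
}

Values below one GiB fell through to format.HumanBytes2 instead.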
@@ -74,7 +74,8 @@ type LlamaServer interface {
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
-	MemorySize() (total, vram uint64)
+	VRAMSize() uint64 // Total VRAM across all GPUs
+	TotalSize() uint64
 	VRAMByGPU(id ml.DeviceID) uint64
 	Pid() int
 	GetPort() int
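Callers of the removed two-value accessor now make two calls against the interface above; a minimal migration sketch (function name illustrative):

// sizesFor replaces the old pattern `total, vram := s.MemorySize()`.
func sizesFor(s LlamaServer) (total, vram uint64) {
	// TotalSize covers input weights plus CPU and GPU components;
	// VRAMSize covers only memory allocated across GPUs.
	return s.TotalSize(), s.VRAMSize()
}

Splitting the accessor lets call sites ask for only the size they need.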
@@ -684,9 +685,8 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
 	// Windows CUDA should not use mmap for best performance
 	// Linux with a model larger than free space, mmap leads to thrashing
 	// For CPU loads we want the memory to be allocated, not FS cache
-	totalSize, _ := s.MemorySize()
 	if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
-		(runtime.GOOS == "linux" && systemInfo.FreeMemory < totalSize && s.options.UseMMap == nil) ||
+		(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.TotalSize() && s.options.UseMMap == nil) ||
 		(len(gpus) == 0 && s.options.UseMMap == nil) ||
 		(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
 		(s.options.UseMMap != nil && !*s.options.UseMMap) {
@@ -1453,12 +1453,10 @@ type ImageData struct {
 }
 
 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
-	Think           *api.ThinkValue
-	ExplicitOptions map[string]struct{}
 
 	Grammar string // set before sending the request to the subprocess
 	Shift   bool
@@ -1520,7 +1518,6 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
 	EvalCount          int           `json:"eval_count"`
 	EvalDuration       time.Duration `json:"eval_duration"`
-	PeakMemory         uint64        `json:"peak_memory,omitempty"`
 
 	// Logprobs contains log probability information if requested
 	Logprobs []Logprob `json:"logprobs,omitempty"`
@@ -1851,17 +1848,17 @@ func (s *llamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
 	return nil
 }
 
-func (s *llmServer) MemorySize() (total, vram uint64) {
+func (s *llmServer) VRAMSize() uint64 {
 	if s.mem == nil {
-		return 0, 0
+		return 0
 	}
 
+	var mem uint64
+
 	for _, g := range s.mem.GPUs {
-		vram += g.Size()
+		mem += g.Size()
 	}
 
-	total = s.mem.InputWeights + s.mem.CPU.Size() + vram
-
 	// Some elements are always on CPU. However, if we have allocated all layers
 	// on the GPU then include the CPU components as well, to represent complete offloading.
 	noCPULayers := true

@@ -1872,11 +1869,25 @@ func (s *llmServer) MemorySize() (total, vram uint64) {
 		}
 	}
 	if noCPULayers {
-		vram += s.mem.InputWeights
-		vram += s.mem.CPU.Graph
+		mem += s.mem.InputWeights
+		mem += s.mem.CPU.Graph
 	}
 
-	return total, vram
+	return mem
+}
+
+func (s *llmServer) TotalSize() uint64 {
+	if s.mem == nil {
+		return 0
+	}
+
+	mem := s.mem.InputWeights
+	mem += s.mem.CPU.Size()
+	for _, g := range s.mem.GPUs {
+		mem += g.Size()
+	}
+
+	return mem
 }
 
 func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
@@ -41,8 +41,8 @@ type GatedDeltaNet struct {
 	SSMBeta   *nn.Linear  `gguf:"ssm_beta"`  // -> beta (qwen35)
 	SSMAlpha  *nn.Linear  `gguf:"ssm_alpha"` // -> alpha (qwen35)
 	SSMConv1D *convKernel `gguf:"ssm_conv1d"`
-	SSMDT     ml.Tensor   `gguf:"ssm_dt,alt:ssm_dt.bias"` // alpha bias
+	SSMDT     ml.Tensor   `gguf:"ssm_dt"` // alpha bias
 	SSMA      ml.Tensor   `gguf:"ssm_a"` // -A_log.exp()
 	SSMNorm   *nn.RMSNorm `gguf:"ssm_norm"`
 	SSMOut    *nn.Linear  `gguf:"ssm_out"`
 
@@ -135,18 +135,6 @@ func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cac
 	default:
 		return nil, errors.New("qwen3next: missing linear attention beta/alpha projections")
 	}
-	if gdn.SSMDT == nil {
-		return nil, errors.New("qwen3next: missing linear attention ssm_dt tensor")
-	}
-	if gdn.SSMA == nil {
-		return nil, errors.New("qwen3next: missing linear attention ssm_a tensor")
-	}
-	if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
-		return nil, errors.New("qwen3next: missing linear attention ssm_conv1d tensor")
-	}
-	if gdn.SSMNorm == nil || gdn.SSMOut == nil {
-		return nil, errors.New("qwen3next: missing linear attention ssm_norm/ssm_out projections")
-	}
 
 	// Compute gate: softplus(alpha + dt_bias) * -A
 	alphaBiased := alpha.Add(ctx, gdn.SSMDT)
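As a reading aid for the gate comment above, a scalar sketch of softplus(alpha + dt_bias) * -A (plain float64 in place of ml.Tensor; names and sample values illustrative):

package main

import (
	"fmt"
	"math"
)

// softplus(x) = log(1 + exp(x)), written to avoid overflow for large x.
func softplus(x float64) float64 {
	if x > 0 {
		return x + math.Log1p(math.Exp(-x))
	}
	return math.Log1p(math.Exp(x))
}

// gate mirrors the comment: softplus(alpha + dtBias) * -a, where a stands in
// for the A value (the SSMA tensor above stores -A_log.exp()).
func gate(alpha, dtBias, a float64) float64 {
	return softplus(alpha+dtBias) * -a
}

func main() {
	fmt.Println(gate(0.5, 0.1, 1.0)) // ≈ -1.04, a negative decay exponent
}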
@@ -437,46 +437,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	return m.Output.Forward(ctx, hiddenStates), nil
 }
 
-func (m *Model) Validate() error {
-	if m.Options == nil {
-		return fmt.Errorf("qwen3next: missing model options")
-	}
-	if len(m.Layers) != len(m.Options.isRecurrent) {
-		return fmt.Errorf("qwen3next: layer config mismatch: have %d layers, %d recurrent flags", len(m.Layers), len(m.Options.isRecurrent))
-	}
-
-	for i, layer := range m.Layers {
-		if !m.Options.isRecurrent[i] {
-			continue
-		}
-
-		gdn, ok := layer.Operator.(*GatedDeltaNet)
-		if !ok || gdn == nil {
-			return fmt.Errorf("qwen3next: layer %d expected recurrent operator", i)
-		}
-		if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
-			return fmt.Errorf("qwen3next: layer %d missing attn_qkv/attn_gate projections", i)
-		}
-		if gdn.SSMBetaAlpha == nil && (gdn.SSMBeta == nil || gdn.SSMAlpha == nil) {
-			return fmt.Errorf("qwen3next: layer %d missing linear attention beta/alpha projections", i)
-		}
-		if gdn.SSMDT == nil {
-			return fmt.Errorf("qwen3next: layer %d missing ssm_dt tensor", i)
-		}
-		if gdn.SSMA == nil {
-			return fmt.Errorf("qwen3next: layer %d missing ssm_a tensor", i)
-		}
-		if gdn.SSMConv1D == nil || gdn.SSMConv1D.Weight == nil {
-			return fmt.Errorf("qwen3next: layer %d missing ssm_conv1d tensor", i)
-		}
-		if gdn.SSMNorm == nil || gdn.SSMOut == nil {
-			return fmt.Errorf("qwen3next: layer %d missing ssm_norm/ssm_out projections", i)
-		}
-	}
-
-	return nil
-}
-
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	m.positionCache = nil
 	if len(m.mropeSections) > 0 {
@@ -490,64 +450,6 @@ var (
 	_ model.MultimodalProcessor = (*Model)(nil)
 )
 
-func defaultVHeadReordered(arch string) bool {
-	return arch == "qwen35" || arch == "qwen35moe"
-}
-
-func inferRecurrentLayers(headCountKV []uint64, numLayers int, fullAttentionInterval uint32) ([]bool, error) {
-	isRecurrent := make([]bool, numLayers)
-
-	hasZero := false
-	hasFull := false
-	for i := range numLayers {
-		if i >= len(headCountKV) {
-			continue
-		}
-
-		if headCountKV[i] == 0 {
-			isRecurrent[i] = true
-			hasZero = true
-		} else {
-			hasFull = true
-		}
-	}
-	if hasZero && hasFull {
-		return isRecurrent, nil
-	}
-	if !hasFull {
-		return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
-	}
-
-	// Compatibility path: older imports store a scalar KV head count and omit
-	// per-layer recurrent flags. Derive the hybrid layout from the interval.
-	interval := int(fullAttentionInterval)
-	if interval == 0 {
-		interval = min(4, numLayers)
-	}
-	if interval <= 0 {
-		return nil, fmt.Errorf("qwen3next: invalid block_count (%d)", numLayers)
-	}
-	if interval > numLayers {
-		return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds block_count (%d)", interval, numLayers)
-	}
-
-	hasZero = false
-	hasFull = false
-	for i := range numLayers {
-		isRecurrent[i] = (i+1)%interval != 0
-		if isRecurrent[i] {
-			hasZero = true
-		} else {
-			hasFull = true
-		}
-	}
-	if !hasZero || !hasFull {
-		return nil, fmt.Errorf("qwen3next: full_attention_interval (%d) does not produce a mixed recurrent/full layout", interval)
-	}
-
-	return isRecurrent, nil
-}
-
 func New(c fs.Config) (model.Model, error) {
 	numLayers := int(c.Uint("block_count"))
 	layers := make([]Layer, numLayers)
@@ -558,14 +460,26 @@ func New(c fs.Config) (model.Model, error) {
 		HeadCountKV() []uint64
 	}
 
+	var isRecurrent []bool
 	var headCountKV []uint64
 	if hc, ok := c.(headCounts); ok {
 		headCountKV = hc.HeadCountKV()
 	}
 
-	isRecurrent, err := inferRecurrentLayers(headCountKV, numLayers, c.Uint("full_attention_interval"))
-	if err != nil {
-		return nil, err
+	isRecurrent = make([]bool, numLayers)
+	hasZero := false
+	hasFull := false
+	for i := range numLayers {
+		// If KV head count is 0, it's a recurrent layer
+		if i < len(headCountKV) && headCountKV[i] == 0 {
+			isRecurrent[i] = true
+			hasZero = true
+		} else if i < len(headCountKV) && headCountKV[i] > 0 {
+			hasFull = true
+		}
+	}
+	if !hasZero || !hasFull {
+		return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
 	}
 
 	// Determine if MoE
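A standalone sketch of the layout rule now inlined above: a zero entry in attention.head_count_kv marks a recurrent (linear attention) layer, and a valid hybrid model needs both kinds. The sample array comes from the tests deleted later in this diff:

package main

import "fmt"

func main() {
	headCountKV := []uint64{0, 2, 0, 2}
	isRecurrent := make([]bool, len(headCountKV))
	hasZero, hasFull := false, false
	for i, n := range headCountKV {
		isRecurrent[i] = n == 0 // zero KV heads means a recurrent layer
		hasZero = hasZero || n == 0
		hasFull = hasFull || n > 0
	}
	fmt.Println(isRecurrent, hasZero && hasFull) // [true false true false] true
}

An all-zero or all-non-zero array now fails with the head_count_kv error above.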
@@ -629,7 +543,7 @@ func New(c fs.Config) (model.Model, error) {
 		ssmNGroup:      int(c.Uint("ssm.group_count")),
 		ssmDtRank:      int(c.Uint("ssm.time_step_rank")),
 		convKernelSize: int(c.Uint("ssm.conv_kernel")),
-		vHeadReordered: c.Bool("ssm.v_head_reordered", defaultVHeadReordered(c.Architecture())),
+		vHeadReordered: c.Bool("ssm.v_head_reordered", false),
 		isRecurrent:    isRecurrent,
 		mropeSections: slices.Collect(func(yield func(int) bool) {
 			for _, section := range mropeSections {
@@ -641,7 +555,7 @@ func New(c fs.Config) (model.Model, error) {
 		mropeInterleaved: c.Bool("rope.mrope_interleaved", c.Bool("mrope_interleaved", false)),
 	}
 	if opts.numKVHeads == 0 {
-		return nil, fmt.Errorf("qwen3next: attention.head_count_kv must include at least one non-zero value")
+		return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
 	}
 
 	// Calculate cache dimensions
@@ -1,65 +0,0 @@
-package qwen3next
-
-import (
-	"slices"
-	"strings"
-	"testing"
-)
-
-func TestInferRecurrentLayersMixedKVArray(t *testing.T) {
-	got, err := inferRecurrentLayers([]uint64{0, 2, 0, 2}, 4, 0)
-	if err != nil {
-		t.Fatalf("inferRecurrentLayers() error = %v", err)
-	}
-
-	want := []bool{true, false, true, false}
-	if !slices.Equal(got, want) {
-		t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
-	}
-}
-
-func TestInferRecurrentLayersScalarKVDefaultInterval(t *testing.T) {
-	got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2, 2, 2}, 8, 0)
-	if err != nil {
-		t.Fatalf("inferRecurrentLayers() error = %v", err)
-	}
-
-	want := []bool{true, true, true, false, true, true, true, false}
-	if !slices.Equal(got, want) {
-		t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
-	}
-}
-
-func TestInferRecurrentLayersScalarKVConfiguredInterval(t *testing.T) {
-	got, err := inferRecurrentLayers([]uint64{2, 2, 2, 2, 2, 2}, 6, 3)
-	if err != nil {
-		t.Fatalf("inferRecurrentLayers() error = %v", err)
-	}
-
-	want := []bool{true, true, false, true, true, false}
-	if !slices.Equal(got, want) {
-		t.Fatalf("inferRecurrentLayers() = %v, want %v", got, want)
-	}
-}
-
-func TestInferRecurrentLayersAllZeroRejects(t *testing.T) {
-	_, err := inferRecurrentLayers([]uint64{0, 0, 0, 0}, 4, 0)
-	if err == nil {
-		t.Fatal("inferRecurrentLayers() expected error, got nil")
-	}
-	if !strings.Contains(err.Error(), "must include at least one non-zero value") {
-		t.Fatalf("unexpected error = %v", err)
-	}
-}
-
-func TestDefaultVHeadReordered(t *testing.T) {
-	if !defaultVHeadReordered("qwen35") {
-		t.Fatal("defaultVHeadReordered(qwen35) = false, want true")
-	}
-	if !defaultVHeadReordered("qwen35moe") {
-		t.Fatal("defaultVHeadReordered(qwen35moe) = false, want true")
-	}
-	if defaultVHeadReordered("qwen3next") {
-		t.Fatal("defaultVHeadReordered(qwen3next) = true, want false")
-	}
-}
@@ -1,45 +0,0 @@
-package qwen3next
-
-import (
-	"strings"
-	"testing"
-
-	"github.com/ollama/ollama/ml/nn"
-)
-
-func TestValidateRecurrentLayerRequiresSSMDT(t *testing.T) {
-	m := &Model{
-		Layers: []Layer{{
-			Operator: &GatedDeltaNet{
-				SSMQKV:     &nn.Linear{},
-				SSMQKVGate: &nn.Linear{},
-				SSMBeta:    &nn.Linear{},
-				SSMAlpha:   &nn.Linear{},
-			},
-		}},
-		Options: &Options{
-			isRecurrent: []bool{true},
-		},
-	}
-
-	err := m.Validate()
-	if err == nil {
-		t.Fatal("Validate() expected error, got nil")
-	}
-	if !strings.Contains(err.Error(), "missing ssm_dt") {
-		t.Fatalf("unexpected error = %v", err)
-	}
-}
-
-func TestValidateNonRecurrentSkipsLinearChecks(t *testing.T) {
-	m := &Model{
-		Layers: []Layer{{Operator: &FullAttention{}}},
-		Options: &Options{
-			isRecurrent: []bool{false},
-		},
-	}
-
-	if err := m.Validate(); err != nil {
-		t.Fatalf("Validate() error = %v", err)
-	}
-}
@@ -32,10 +32,9 @@ const (
 )
 
 type GLM46Parser struct {
 	state  glm46ParserState
 	buffer strings.Builder
 	tools  []api.Tool
-	callIndex int
 }
 
 func (p *GLM46Parser) HasToolSupport() bool {
@@ -49,7 +48,6 @@ func (p *GLM46Parser) HasThinkingSupport() bool {
 // func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
 func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
-	p.callIndex = 0
 	return tools
 }
 
@@ -91,8 +89,6 @@ func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string,
 				slog.Warn("glm-4.6 tool call parsing failed", "error", err)
 				return "", "", nil, err
 			}
-			toolCall.Function.Index = p.callIndex
-			p.callIndex++
 			toolCalls = append(toolCalls, toolCall)
 		case glm46EventThinkingContent:
 			thinkingSb.WriteString(event.content)
@@ -11,7 +11,6 @@ type GLM47Parser struct {
 
 func (p *GLM47Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
-	p.callIndex = 0
 	// When thinking is enabled (nil or true), the prompt ends with <think>,
 	// so model output starts directly with thinking content (no opening tag).
 	if thinkValue == nil || thinkValue.Bool() {
@@ -97,91 +97,3 @@ func TestGLM47ParserToolCallEscaping(t *testing.T) {
 		t.Fatalf("expected %#v, got %#v", expected, toolCall)
 	}
 }
-
-func TestGLM47ParserToolCallIndexing(t *testing.T) {
-	parser := GLM47Parser{}
-	parser.Init(nil, nil, nil)
-
-	input := `plan</think>
-<tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>
-<tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>
-<tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>`
-
-	_, _, calls, err := parser.Add(input, true)
-	if err != nil {
-		t.Fatalf("parse failed: %v", err)
-	}
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
-	}
-	if len(calls) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(calls))
-	}
-	for i := range want {
-		if !toolCallEqual(calls[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
-		}
-	}
-}
-
-func TestGLM47ParserToolCallIndexingStreaming(t *testing.T) {
-	parser := GLM47Parser{}
-	parser.Init(nil, nil, nil)
-
-	var all []api.ToolCall
-
-	_, _, calls, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call><tool_call>second<arg_key>b</arg_key>", false)
-	if err != nil {
-		t.Fatalf("step 1 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	_, _, calls, err = parser.Add("<arg_value>2</arg_value></tool_call><tool_call>third<arg_key>c</arg_key><arg_value>3</arg_value></tool_call>", true)
-	if err != nil {
-		t.Fatalf("step 2 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
-	}
-	if len(all) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(all))
-	}
-	for i := range want {
-		if !toolCallEqual(all[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
-		}
-	}
-}
-
-func TestGLM47ParserToolCallIndexResetOnInit(t *testing.T) {
-	parser := GLM47Parser{}
-	parser.Init(nil, nil, nil)
-
-	_, _, _, err := parser.Add("plan</think><tool_call>first<arg_key>a</arg_key><arg_value>1</arg_value></tool_call>", true)
-	if err != nil {
-		t.Fatalf("first parse failed: %v", err)
-	}
-
-	parser.Init(nil, nil, nil)
-	_, _, calls, err := parser.Add("plan</think><tool_call>second<arg_key>b</arg_key><arg_value>2</arg_value></tool_call>", true)
-	if err != nil {
-		t.Fatalf("second parse failed: %v", err)
-	}
-
-	want := api.ToolCall{
-		Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
-	}
-	if len(calls) != 1 {
-		t.Fatalf("expected 1 call, got %d", len(calls))
-	}
-	if !toolCallEqual(calls[0], want) {
-		t.Fatalf("got %#v, want %#v", calls[0], want)
-	}
-}
@@ -38,7 +38,6 @@ type Qwen3Parser struct {
 	state  qwen3ParserState
 	buffer strings.Builder
 	tools  []api.Tool
-	callIndex int
 	hasThinkingSupport     bool
 	defaultThinking        bool
 	maybeThinkingOpenAtBOL bool
@@ -55,7 +54,6 @@ func (p *Qwen3Parser) HasThinkingSupport() bool {
 func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
 	p.buffer.Reset()
-	p.callIndex = 0
 
 	thinkingEnabled := thinkValue != nil && thinkValue.Bool()
 	if thinkValue == nil {
@@ -108,8 +106,6 @@ func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string,
 				slog.Warn("qwen3 tool call parsing failed", "error", err)
 				return "", "", nil, err
 			}
-			toolCall.Function.Index = p.callIndex
-			p.callIndex++
 			calls = append(calls, toolCall)
 		case qwen3EventThinkingContent:
 			thinkingSb.WriteString(event.content)
@@ -230,89 +230,3 @@ func TestQwen35ParserRespectsNoThink(t *testing.T) {
 		t.Fatalf("expected no tool calls, got %d", len(calls))
 	}
 }
-
-func TestQwen3ParserToolCallIndexing(t *testing.T) {
-	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
-	parser.Init(nil, nil, &api.ThinkValue{Value: false})
-
-	input := `<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>
-<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>
-<tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`
-	_, _, calls, err := parser.Add(input, true)
-	if err != nil {
-		t.Fatalf("parse failed: %v", err)
-	}
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
-	}
-	if len(calls) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(calls))
-	}
-	for i := range want {
-		if !toolCallEqual(calls[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
-		}
-	}
-}
-
-func TestQwen3ParserToolCallIndexingStreaming(t *testing.T) {
-	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
-	parser.Init(nil, nil, &api.ThinkValue{Value: false})
-
-	var all []api.ToolCall
-
-	_, _, calls, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call><tool_call>{"name":"second","arguments":{"b":"2"}`, false)
-	if err != nil {
-		t.Fatalf("step 1 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	_, _, calls, err = parser.Add(`}</tool_call><tool_call>{"name":"third","arguments":{"c":"3"}}</tool_call>`, true)
-	if err != nil {
-		t.Fatalf("step 2 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: args(`{"a":"1"}`), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: args(`{"c":"3"}`), Index: 2}},
-	}
-	if len(all) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(all))
-	}
-	for i := range want {
-		if !toolCallEqual(all[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
-		}
-	}
-}
-
-func TestQwen3ParserToolCallIndexResetOnInit(t *testing.T) {
-	parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
-	parser.Init(nil, nil, &api.ThinkValue{Value: false})
-
-	_, _, _, err := parser.Add(`<tool_call>{"name":"first","arguments":{"a":"1"}}</tool_call>`, true)
-	if err != nil {
-		t.Fatalf("first parse failed: %v", err)
-	}
-
-	parser.Init(nil, nil, &api.ThinkValue{Value: false})
-	_, _, calls, err := parser.Add(`<tool_call>{"name":"second","arguments":{"b":"2"}}</tool_call>`, true)
-	if err != nil {
-		t.Fatalf("second parse failed: %v", err)
-	}
-
-	want := api.ToolCall{
-		Function: api.ToolCallFunction{Name: "second", Arguments: args(`{"b":"2"}`), Index: 0},
-	}
-	if len(calls) != 1 {
-		t.Fatalf("expected 1 call, got %d", len(calls))
-	}
-	if !toolCallEqual(calls[0], want) {
-		t.Fatalf("got %#v, want %#v", calls[0], want)
-	}
-}
@@ -29,10 +29,9 @@ const (
 )
 
 type Qwen3CoderParser struct {
 	state qwenParserState
 	acc   strings.Builder
 	tools []api.Tool
-	callIndex int
 }
 
 func (p *Qwen3CoderParser) HasToolSupport() bool {
@@ -45,7 +44,6 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
 
 func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
 	p.tools = tools
-	p.callIndex = 0
 	return tools // Qwen doesn't modify tools
 }
 
@@ -64,8 +62,6 @@ func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking st
 				slog.Warn("qwen tool call parsing failed", "error", err)
 				return "", "", nil, err
 			}
-			toolCall.Function.Index = p.callIndex
-			p.callIndex++
 			toolCalls = append(toolCalls, toolCall)
 		case qwenEventContent:
 			// TODO(drifkin): if the same turn contains multiple interleaved content
@@ -1035,92 +1035,6 @@ func TestQwenToolCallValueParsing(t *testing.T) {
 		}
 	}
-
-func TestQwen3CoderParserToolCallIndexing(t *testing.T) {
-	parser := Qwen3CoderParser{}
-	parser.Init(nil, nil, nil)
-
-	input := `<tool_call><function=first><parameter=a>1</parameter></function></tool_call>
-<tool_call><function=second><parameter=b>2</parameter></function></tool_call>
-<tool_call><function=third><parameter=c>3</parameter></function></tool_call>`
-	_, _, calls, err := parser.Add(input, true)
-	if err != nil {
-		t.Fatalf("parse failed: %v", err)
-	}
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
-	}
-	if len(calls) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(calls))
-	}
-	for i := range want {
-		if !toolCallEqual(calls[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, calls[i], want[i])
-		}
-	}
-}
-
-func TestQwen3CoderParserToolCallIndexingStreaming(t *testing.T) {
-	parser := Qwen3CoderParser{}
-	parser.Init(nil, nil, nil)
-
-	var all []api.ToolCall
-
-	_, _, calls, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call><tool_call><function=second>", false)
-	if err != nil {
-		t.Fatalf("step 1 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	_, _, calls, err = parser.Add("<parameter=b>2</parameter></function></tool_call><tool_call><function=third><parameter=c>3</parameter></function></tool_call>", true)
-	if err != nil {
-		t.Fatalf("step 2 parse failed: %v", err)
-	}
-	all = append(all, calls...)
-
-	want := []api.ToolCall{
-		{Function: api.ToolCallFunction{Name: "first", Arguments: testArgs(map[string]any{"a": "1"}), Index: 0}},
-		{Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 1}},
-		{Function: api.ToolCallFunction{Name: "third", Arguments: testArgs(map[string]any{"c": "3"}), Index: 2}},
-	}
-	if len(all) != len(want) {
-		t.Fatalf("expected %d calls, got %d", len(want), len(all))
-	}
-	for i := range want {
-		if !toolCallEqual(all[i], want[i]) {
-			t.Fatalf("call %d mismatch: got %#v, want %#v", i, all[i], want[i])
-		}
-	}
-}
-
-func TestQwen3CoderParserToolCallIndexResetOnInit(t *testing.T) {
-	parser := Qwen3CoderParser{}
-	parser.Init(nil, nil, nil)
-
-	_, _, _, err := parser.Add("<tool_call><function=first><parameter=a>1</parameter></function></tool_call>", true)
-	if err != nil {
-		t.Fatalf("first parse failed: %v", err)
-	}
-
-	parser.Init(nil, nil, nil)
-	_, _, calls, err := parser.Add("<tool_call><function=second><parameter=b>2</parameter></function></tool_call>", true)
-	if err != nil {
-		t.Fatalf("second parse failed: %v", err)
-	}
-
-	want := api.ToolCall{
-		Function: api.ToolCallFunction{Name: "second", Arguments: testArgs(map[string]any{"b": "2"}), Index: 0},
-	}
-	if len(calls) != 1 {
-		t.Fatalf("expected 1 call, got %d", len(calls))
-	}
-	if !toolCallEqual(calls[0], want) {
-		t.Fatalf("got %#v, want %#v", calls[0], want)
-	}
-}
 
 func TestQwenXMLTransform(t *testing.T) {
 	cases := []struct {
 		desc string
@@ -71,10 +71,6 @@ type Model struct {
 	Template *template.Template
 }
 
-func (m *Model) IsMLX() bool {
-	return m.Config.ModelFormat == "safetensors"
-}
-
 // Capabilities returns the capabilities that the model supports
 func (m *Model) Capabilities() []model.Capability {
 	capabilities := []model.Capability{}
@@ -30,44 +30,42 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	lastMsgIdx := len(msgs) - 1
 	currMsgIdx := 0
 
-	if truncate {
-		// Start with all messages and remove from the front until it fits in context
-		for i := 0; i <= lastMsgIdx; i++ {
-			// Collect system messages from the portion we're about to skip
-			system = make([]api.Message, 0)
-			for j := range i {
-				if msgs[j].Role == "system" {
-					system = append(system, msgs[j])
-				}
-			}
+	// Start with all messages and remove from the front until it fits in context
+	for i := 0; i <= lastMsgIdx; i++ {
+		// Collect system messages from the portion we're about to skip
+		system = make([]api.Message, 0)
+		for j := range i {
+			if msgs[j].Role == "system" {
+				system = append(system, msgs[j])
+			}
+		}
 
 		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
 		if err != nil {
 			return "", nil, err
 		}
 
 		s, err := tokenize(ctx, p)
		if err != nil {
 			return "", nil, err
 		}
 
 		ctxLen := len(s)
 		if m.ProjectorPaths != nil {
 			for _, msg := range msgs[i:] {
 				ctxLen += imageNumTokens * len(msg.Images)
 			}
 		}
 
-		if ctxLen <= opts.NumCtx {
+		if !truncate || ctxLen <= opts.NumCtx {
 			currMsgIdx = i
 			break
 		}
 
 		// Must always include at least the last message
 		if i == lastMsgIdx {
 			currMsgIdx = lastMsgIdx
 			break
-		}
 		}
 	}
 
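The window-shrinking loop now always runs; with truncation disabled, the new predicate accepts the full message list on the first pass rather than skipping the loop entirely. A standalone sketch of the predicate (names illustrative):

package main

import "fmt"

// keep reports whether the window starting at the current message is accepted.
func keep(truncate bool, ctxLen, numCtx int) bool {
	return !truncate || ctxLen <= numCtx
}

func main() {
	fmt.Println(keep(false, 99999, 4096)) // true: truncation off, full history kept
	fmt.Println(keep(true, 5000, 4096))   // false: drop the oldest message, retry
	fmt.Println(keep(true, 3000, 4096))   // true: fits within the context window
}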
@@ -130,35 +130,6 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
 	return opts, nil
 }
 
-func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
-	keys := []string{
-		"temperature",
-		"top_p",
-		"min_p",
-		"top_k",
-		"repeat_last_n",
-		"repeat_penalty",
-		"presence_penalty",
-		"frequency_penalty",
-	}
-
-	explicit := make(map[string]struct{}, len(keys))
-	for _, key := range keys {
-		if optionSpecified(modelOpts, requestOpts, key) {
-			explicit[key] = struct{}{}
-		}
-	}
-	return explicit
-}
-
-func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
-	if _, ok := requestOpts[key]; ok {
-		return true
-	}
-	_, ok := modelOpts[key]
-	return ok
-}
-
 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
 func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
@@ -513,8 +484,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		// the real chat handler, but doing this as a stopgap to get renderer
 		// support for generate
 		if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-			genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
-			prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
+			prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
Think: req.Think,
|
Shift: req.Shift == nil || *req.Shift,
|
||||||
ExplicitOptions: explicitOptions(m.Options, req.Options),
|
Truncate: req.Truncate == nil || *req.Truncate,
|
||||||
Shift: req.Shift == nil || *req.Shift,
|
Logprobs: req.Logprobs,
|
||||||
Truncate: req.Truncate == nil || *req.Truncate,
|
TopLogprobs: req.TopLogprobs,
|
||||||
Logprobs: req.Logprobs,
|
|
||||||
TopLogprobs: req.TopLogprobs,
|
|
||||||
}, func(cr llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
res := api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@@ -589,7 +557,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					PromptEvalDuration: cr.PromptEvalDuration,
 					EvalCount:          cr.EvalCount,
 					EvalDuration:       cr.EvalDuration,
-					PeakMemory:         cr.PeakMemory,
 				},
 				Logprobs: toAPILogprobs(cr.Logprobs),
 			}
@@ -1984,9 +1951,6 @@ func (s *Server) PsHandler(c *gin.Context) {
 		}
 		if v.llama != nil {
 			mr.ContextLength = v.llama.ContextLength()
-			total, vram := v.llama.MemorySize()
-			mr.Size = int64(total)
-			mr.SizeVRAM = int64(vram)
 		}
 		// The scheduler waits to set expiresAt, so if a model is loading it's
 		// possible that it will be set to the unix epoch. For those cases, just
@@ -2249,9 +2213,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	truncate := req.Truncate == nil || *req.Truncate
-	if m.IsMLX() {
-		truncate = false
-	}
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
@@ -2329,16 +2290,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// sets up new context given parent context per request
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		err := r.Completion(ctx, llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  currentFormat,
 			Options: opts,
-			Think:           req.Think,
-			ExplicitOptions: explicitOptions(m.Options, req.Options),
-			Shift:           req.Shift == nil || *req.Shift,
-			Truncate:        truncate,
-			Logprobs:        req.Logprobs,
-			TopLogprobs:     req.TopLogprobs,
+			Shift:       req.Shift == nil || *req.Shift,
+			Truncate:    truncate,
+			Logprobs:    req.Logprobs,
+			TopLogprobs: req.TopLogprobs,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
@@ -2350,7 +2309,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
 					EvalDuration:       r.EvalDuration,
-					PeakMemory:         r.PeakMemory,
 				},
 				Logprobs: toAPILogprobs(r.Logprobs),
 			}
@@ -231,7 +231,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			}
 
 			// Check for experimental safetensors LLM models
-			if pending.model.IsMLX() {
+			if pending.model.Config.ModelFormat == "safetensors" {
 				if slices.Contains(pending.model.Config.Capabilities, "completion") {
 					// LLM model with safetensors format - use MLX runner
 					if s.loadMLX(pending) {
@@ -536,7 +536,6 @@ iGPUScan:
 		}
 	}
 
-	totalSize, vramSize := llama.MemorySize()
 	runner := &runnerRef{
 		model:     req.model,
 		modelPath: req.model.ModelPath,
@@ -546,8 +545,8 @@ iGPUScan:
 		sessionDuration: sessionDuration,
 		gpus:            gpuIDs,
 		discreteGPUs:    discreteGPUs,
-		totalSize:       totalSize,
-		vramSize:        vramSize,
+		vramSize:        llama.VRAMSize(),
+		totalSize:       llama.TotalSize(),
 		loading:         true,
 		pid:             llama.Pid(),
 	}
@@ -620,7 +619,6 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 		sessionDuration = req.sessionDuration.Duration
 	}
 
-	totalSize, vramSize := server.MemorySize()
 	runner := &runnerRef{
 		model:     req.model,
 		modelPath: req.model.ModelPath,
@@ -630,8 +628,8 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
 		loading:         false,
 		isImagegen:      isImagegen,
 		sessionDuration: sessionDuration,
-		totalSize:       totalSize,
-		vramSize:        vramSize,
+		totalSize:       server.TotalSize(),
+		vramSize:        server.VRAMSize(),
 	}
 
 	s.loadedMu.Lock()
@@ -764,7 +762,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 	defer cancel()
 	if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
 		!reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths) || // have the projectors changed?
-		(!runner.model.IsMLX() && !reflect.DeepEqual(optsExisting, optsNew)) || // have the runner options changed?
+		!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
 		runner.llama.Ping(ctx) != nil {
 		return true
 	}
@@ -861,7 +861,8 @@ func (s *mockLlm) Close() error {
 	s.closeCalled = true
 	return s.closeResp
 }
-func (s *mockLlm) MemorySize() (uint64, uint64)    { return s.totalSize, s.vramSize }
+func (s *mockLlm) VRAMSize() uint64                { return s.vramSize }
+func (s *mockLlm) TotalSize() uint64               { return s.totalSize }
 func (s *mockLlm) VRAMByGPU(id ml.DeviceID) uint64 { return s.vramByGPU[id] }
 func (s *mockLlm) Pid() int                        { return -1 }
 func (s *mockLlm) GetPort() int                    { return -1 }
@@ -288,18 +288,6 @@ func normalizeQuantType(quantize string) string {
 	}
 }
 
-func isStackedExpertWeight(name string) bool {
-	// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
-	// or "...proj" (pre-stacked packed tensor).
-	if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") {
-		return false
-	}
-
-	return strings.Contains(name, ".mlp.switch_mlp.") ||
-		strings.Contains(name, ".mlp.experts.") ||
-		strings.Contains(name, ".mlp.shared_experts.")
-}
-
 // GetTensorQuantization returns the appropriate quantization type for a tensor.
 // Returns "" if the tensor should not be quantized.
 // This implements mixed-precision quantization:
@@ -308,25 +296,18 @@ func isStackedExpertWeight(name string) bool {
|
|||||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||||
// - Norms, embeddings, biases, routing gates: no quantization
|
// - Norms, embeddings, biases, routing gates: no quantization
|
||||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||||
stackedExpert := isStackedExpertWeight(name)
|
|
||||||
|
|
||||||
// Use basic name-based check first
|
// Use basic name-based check first
|
||||||
if !stackedExpert && !ShouldQuantize(name, "") {
|
if !ShouldQuantize(name, "") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quantize standard linear weights (2D). Also allow stacked expert weights (3D),
|
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
|
||||||
// e.g. qwen switch_mlp / experts combined tensors.
|
if len(shape) != 2 {
|
||||||
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
// Skip small tensors (less than 1024 elements) - not worth quantizing
|
||||||
var elems int64 = 1
|
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
|
||||||
for _, d := range shape {
|
|
||||||
elems *= int64(d)
|
|
||||||
}
|
|
||||||
if elems < 1024 {
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
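After this change the shape gate is simply "2D and at least 1024 elements". A self-contained sketch of how that guard classifies a few representative shapes — the cases and the demo program are illustrative, not from this diff:

package main

import "fmt"

// Sketch: the 2D-only and >=1024-element rules from GetTensorQuantization,
// applied to made-up shapes.
func main() {
	cases := []struct {
		shape []int32
		note  string
	}{
		{[]int32{4096, 4096}, "large linear weight: quantize"},
		{[]int32{4096}, "1D bias/norm: skip"},
		{[]int32{64, 64, 3}, "3D conv (and stacked experts): skip under the new rule"},
		{[]int32{16, 16}, "256 elements < 1024: too small"},
	}
	for _, c := range cases {
		ok := len(c.shape) == 2 && int64(c.shape[0])*int64(c.shape[1]) >= 1024
		fmt.Printf("shape=%v quantizable=%v (%s)\n", c.shape, ok, c.note)
	}
}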
@@ -557,10 +557,6 @@ func TestShouldQuantizeTensor(t *testing.T) {
 		// 3D+ tensors should not be quantized
 		{"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
 		{"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
-		{"stacked expert switch_mlp gate_up 3D int8", "model.layers.1.mlp.switch_mlp.gate_up_proj.weight", []int32{64, 22016, 4096}, "int8", true},
-		{"stacked expert experts down_proj 3D int8", "model.layers.1.mlp.experts.down_proj.weight", []int32{64, 4096, 14336}, "int8", true},
-		{"stacked expert combined gate_up 3D int8", "model.language_model.layers.0.mlp.experts.gate_up_proj", []int32{256, 1024, 2048}, "int8", true},
-		{"stacked expert combined down_proj 3D int8", "model.language_model.layers.0.mlp.experts.down_proj", []int32{256, 2048, 512}, "int8", true},
-
 		// Embeddings should not be quantized regardless of shape
 		{"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
@@ -623,44 +619,6 @@ func TestExpertGroupPrefix(t *testing.T) {
 	}
 }
 
-func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
-	gateUp := GetTensorQuantization(
-		"model.layers.1.mlp.switch_mlp.gate_up_proj.weight",
-		[]int32{64, 22016, 4096},
-		"int4",
-	)
-	if gateUp != "int4" {
-		t.Fatalf("gate_up_proj quantization = %q, want %q", gateUp, "int4")
-	}
-
-	down := GetTensorQuantization(
-		"model.layers.1.mlp.experts.down_proj.weight",
-		[]int32{64, 4096, 14336},
-		"int4",
-	)
-	if down != "int8" {
-		t.Fatalf("down_proj quantization = %q, want %q", down, "int8")
-	}
-
-	combinedGateUp := GetTensorQuantization(
-		"model.language_model.layers.0.mlp.experts.gate_up_proj",
-		[]int32{256, 1024, 2048},
-		"int8",
-	)
-	if combinedGateUp != "int8" {
-		t.Fatalf("combined gate_up_proj quantization = %q, want %q", combinedGateUp, "int8")
-	}
-
-	combinedDown := GetTensorQuantization(
-		"model.language_model.layers.0.mlp.experts.down_proj",
-		[]int32{256, 2048, 512},
-		"int4",
-	)
-	if combinedDown != "int8" {
-		t.Fatalf("combined down_proj quantization = %q, want %q", combinedDown, "int8")
-	}
-}
-
 func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
 	dir := t.TempDir()
 
@@ -374,9 +374,14 @@ func (s *Server) Close() error {
 	return nil
 }
 
-// MemorySize returns the total and VRAM memory usage.
-func (s *Server) MemorySize() (total, vram uint64) {
-	return s.vramSize, s.vramSize
+// VRAMSize returns the estimated VRAM usage.
+func (s *Server) VRAMSize() uint64 {
+	return s.vramSize
+}
+
+// TotalSize returns the total memory usage.
+func (s *Server) TotalSize() uint64 {
+	return s.vramSize
 }
 
 // VRAMByGPU returns VRAM usage for a specific GPU.
@@ -30,64 +30,21 @@ type cacheSession struct {
 	remaining []int32
 }
 
-func (c *kvCache) free() {
-	for i, kv := range c.caches {
-		if kv == nil {
-			continue
-		}
-		kv.Free()
-		c.caches[i] = nil
-	}
-	c.caches = nil
-	c.tokens = nil
-}
-
-func (c *kvCache) cachesCanTrim() bool {
-	for _, kv := range c.caches {
-		if kv == nil {
-			continue
-		}
-		if !kv.CanTrim() {
-			return false
-		}
-	}
-	return true
-}
-
-func (c *kvCache) trimToPrefix(prefix int) {
-	for _, kv := range c.caches {
-		if kv == nil || !kv.CanTrim() {
-			continue
-		}
-		if trim := kv.Offset() - prefix; trim > 0 {
-			kv.Trim(trim)
-		}
-	}
-	if prefix < len(c.tokens) {
-		c.tokens = c.tokens[:prefix]
-	}
-}
-
 // begin prepares caches for a new request. It finds the nearest
 // matching cache or creates new caches if none match.
 func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
-	ensureCaches := func() {
-		if len(c.caches) != 0 {
-			return
-		}
+	if len(c.caches) == 0 {
 		if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
 			c.caches = cacheFactory.NewCaches()
-			return
-		}
-		c.caches = make([]cache.Cache, m.NumLayers())
-		for i := range c.caches {
-			c.caches[i] = cache.NewKVCache()
+		} else {
+			c.caches = make([]cache.Cache, m.NumLayers())
+			for i := range c.caches {
+				c.caches[i] = cache.NewKVCache()
+			}
 		}
 	}
-	ensureCaches()
 
 	remaining := c.findRemaining(inputs)
-	ensureCaches()
 
 	return &cacheSession{
 		cache: c,
@@ -99,34 +56,18 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
 
 // close saves the token state if the forward pass ran.
 func (s *cacheSession) close() {
-	if len(s.caches) == 0 {
-		return
-	}
-
-	offset := -1
-	arrays := make([]*mlx.Array, 0, 2*len(s.caches))
-	for _, kv := range s.caches {
-		if kv == nil {
-			continue
-		}
-		if off := kv.Offset(); offset < 0 || off < offset {
-			offset = off
-		}
-		arrays = append(arrays, kv.Materialize()...)
-	}
-	if offset <= 0 {
-		return
-	}
-
-	// Ensure that if we have run the forward pass and set the metadata
-	// that we also actually have the data.
-	mlx.AsyncEval(arrays...)
-
-	stored := append(s.inputs, s.outputs...)
-	if offset > len(stored) {
-		offset = len(stored)
-	}
-	s.cache.tokens = stored[:offset]
+	if offset := s.caches[0].Offset(); offset > 0 {
+		// Ensure that if we have run the forward pass and set the metadata
+		// that we also actually have the data
+		arrays := make([]*mlx.Array, 0, 2*len(s.caches))
+		for _, c := range s.caches {
+			k, v := c.State()
+			arrays = append(arrays, k, v)
+		}
+		mlx.AsyncEval(arrays...)
+
+		s.cache.tokens = append(s.inputs, s.outputs...)[:offset]
+	}
 }
 
 // findRemaining finds the longest common prefix between tokens and the cached
@@ -137,20 +78,17 @@ func (c *kvCache) findRemaining(tokens []int32) []int32 {
 		prefix++
 	}
 
-	// Always keep at least one token to re-evaluate so the
-	// pipeline can seed token generation from it.
 	if prefix == len(tokens) && prefix > 0 {
+		// Leave one token to run through the model so we can sample a response.
 		prefix--
 	}
 
 	if prefix < len(c.tokens) {
-		if c.cachesCanTrim() {
-			c.trimToPrefix(prefix)
-		} else {
-			c.free()
-			slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
-			return tokens
+		trim := len(c.tokens) - prefix
+		for _, kv := range c.caches {
+			kv.Trim(trim)
 		}
+		c.tokens = c.tokens[:prefix]
 	}
 
 	if prefix == 0 {
@@ -165,21 +103,10 @@ func (c *kvCache) log() {
 	if len(c.caches) == 0 {
 		return
 	}
-	offset := -1
 	var totalBytes int
 	for _, kv := range c.caches {
-		if kv == nil {
-			continue
-		}
-		if off := kv.Offset(); offset < 0 || off < offset {
-			offset = off
-		}
-		for _, a := range kv.Materialize() {
-			totalBytes += a.NumBytes()
-		}
+		k, v := kv.State()
+		totalBytes += k.NumBytes() + v.NumBytes()
 	}
-	if offset < 0 {
-		return
-	}
-	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
+	logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
 }
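With the trimmable-only caches gone, a cache hit reduces to simple index arithmetic: find the shared prefix, hold one token back for sampling, and trim what diverged. A standalone sketch of that computation — the token values are made up for illustration:

package main

import "fmt"

// Sketch of the prefix matching used by findRemaining above.
func main() {
	cached := []int32{1, 2, 3, 4, 5}
	prompt := []int32{1, 2, 3, 9, 10}

	prefix := 0
	for prefix < len(cached) && prefix < len(prompt) && cached[prefix] == prompt[prefix] {
		prefix++
	}
	// Leave one token to run through the model so sampling has a seed.
	if prefix == len(prompt) && prefix > 0 {
		prefix--
	}
	trim := len(cached) - prefix // how much each layer cache must drop

	fmt.Println("matched prefix:", prefix)    // 3
	fmt.Println("trim per cache:", trim)      // 2
	fmt.Println("to evaluate:", prompt[prefix:]) // [9 10]
}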
18	x/mlxrunner/cache/cache.go (vendored)
@@ -10,8 +10,6 @@ import (
 type Cache interface {
 	Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
 	State() (keys, values *mlx.Array)
-	Materialize() []*mlx.Array
-	CanTrim() bool
 	Trim(int) int
 	Clone() Cache
 	Free()
@@ -69,20 +67,6 @@ func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
 		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
 }
 
-// Materialize returns the backing key/value buffers currently held by the cache.
-func (c *KVCache) Materialize() []*mlx.Array {
-	out := make([]*mlx.Array, 0, 2)
-	if c.keys != nil && c.keys.Valid() {
-		out = append(out, c.keys)
-	}
-	if c.values != nil && c.values.Valid() {
-		out = append(out, c.values)
-	}
-	return out
-}
-
-func (c *KVCache) CanTrim() bool { return true }
-
 func (c *KVCache) Trim(n int) int {
 	n = min(c.offset, n)
 	c.offset -= n
@@ -206,8 +190,6 @@ func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
 	return c.keys, c.values
 }
 
-func (c *RotatingKVCache) CanTrim() bool { return true }
-
 func (c *RotatingKVCache) Trim(n int) int {
 	n = min(c.offset, n)
 	c.offset -= n
220	x/mlxrunner/cache/recurrent.go (vendored)
@@ -1,220 +0,0 @@
-//go:build mlx
-
-package cache
-
-import "github.com/ollama/ollama/x/mlxrunner/mlx"
-
-// RecurrentCache stores state for linear-recurrent layers.
-//
-// Conv state shape: [B, convTail, convDim]
-// Delta state shape: [B, numVHeads, headVDim, headKDim]
-type RecurrentCache struct {
-	convState  *mlx.Array
-	deltaState *mlx.Array
-	offset     int
-
-	convTail  int
-	convDim   int
-	numVHeads int
-	headVDim  int
-	headKDim  int
-}
-
-func (c *RecurrentCache) setStateMaterialized(dst **mlx.Array, v *mlx.Array) {
-	if v == nil || !v.Valid() {
-		return
-	}
-	if *dst == v {
-		return
-	}
-
-	// Break dependency chains so recurrent state does not retain the full
-	// per-token compute graph over time.
-	snap := mlx.Snapshot(v)
-	mlx.Eval(snap)
-
-	old := *dst
-	*dst = snap
-	mlx.Pin(snap)
-
-	// Drop references to the previous cached state root and transient incoming
-	// graph root now that a detached snapshot is retained in cache. Actual
-	// cleanup happens at the runner's normal sweep points.
-	if old != nil && old != snap {
-		mlx.Unpin(old)
-	}
-	if v != snap && v != old {
-		mlx.Unpin(v)
-	}
-}
-
-func (c *RecurrentCache) setStateRaw(dst **mlx.Array, v *mlx.Array) {
-	if v == nil || !v.Valid() {
-		return
-	}
-	if *dst == v {
-		return
-	}
-
-	old := *dst
-	*dst = v
-	mlx.Pin(v)
-	if old != nil && old != v {
-		mlx.Unpin(old)
-	}
-}
-
-func (c *RecurrentCache) setStateDetached(dst **mlx.Array, v *mlx.Array, ensureContiguous bool) {
-	if v == nil || !v.Valid() {
-		return
-	}
-	if *dst == v {
-		return
-	}
-
-	root := v
-	if ensureContiguous {
-		root = mlx.Contiguous(v, false)
-	}
-	detached := mlx.Detach(root)
-
-	old := *dst
-	*dst = detached
-	mlx.Pin(detached)
-	if old != nil && old != detached {
-		mlx.Unpin(old)
-	}
-
-	// Intentionally do not force-release root/v here. In the fast path, the detached
-	// handle aliases the same MLX value and may still be lazily computed. Releasing the
-	// source handles can invalidate the cached state before the next eval/sweep point.
-}
-
-func snapshotPinned(a *mlx.Array) *mlx.Array {
-	if a == nil || !a.Valid() {
-		return nil
-	}
-	snap := mlx.Snapshot(a)
-	mlx.Eval(snap)
-	mlx.Pin(snap)
-	return snap
-}
-
-func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
-	return &RecurrentCache{
-		convTail:  int(convTail),
-		convDim:   int(convDim),
-		numVHeads: int(numVHeads),
-		headVDim:  int(headVDim),
-		headKDim:  int(headKDim),
-	}
-}
-
-func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
-	if batch <= 0 {
-		batch = 1
-	}
-
-	needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
-		c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
-	needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
-		c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
-	if !needConv && !needDelta {
-		return
-	}
-
-	if needConv {
-		c.setStateRaw(&c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
-	}
-	if needDelta {
-		c.setStateRaw(&c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
-	}
-}
-
-func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
-	c.ensure(batch, dtype)
-	return c.convState
-}
-
-func (c *RecurrentCache) SetConvState(v *mlx.Array) {
-	c.setStateMaterialized(&c.convState, v)
-}
-
-// SetConvStateFast stores conv state without forcing an immediate snapshot/eval.
-// Use only for decode hot paths that accept higher transient memory until the next
-// sync/sweep point. The conv-state input is usually a slice view, so request a
-// compact contiguous copy to avoid pinning the whole source buffer.
-func (c *RecurrentCache) SetConvStateFast(v *mlx.Array) {
-	c.setStateDetached(&c.convState, v, true)
-}
-
-func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
-	c.ensure(batch, dtype)
-	return c.deltaState
-}
-
-func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
-	c.setStateMaterialized(&c.deltaState, v)
-}
-
-// SetDeltaStateFast stores delta state without forcing an immediate snapshot/eval.
-// Use only for decode hot paths that accept higher transient memory until the next
-// sync/sweep point.
-func (c *RecurrentCache) SetDeltaStateFast(v *mlx.Array) {
-	c.setStateDetached(&c.deltaState, v, false)
-}
-
-func (c *RecurrentCache) Advance(n int) {
-	c.offset += n
-}
-
-func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
-	return keys, values
-}
-
-func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
-	return c.convState, c.deltaState
-}
-
-// Materialize returns the recurrent state roots (conv and delta) held by the cache.
-func (c *RecurrentCache) Materialize() []*mlx.Array {
-	out := make([]*mlx.Array, 0, 2)
-	if c.convState != nil && c.convState.Valid() {
-		out = append(out, c.convState)
-	}
-	if c.deltaState != nil && c.deltaState.Valid() {
-		out = append(out, c.deltaState)
-	}
-	return out
-}
-
-func (c *RecurrentCache) CanTrim() bool { return false }
-
-func (c *RecurrentCache) Trim(n int) int {
-	// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
-	_ = n
-	return 0
-}
-
-func (c *RecurrentCache) Clone() Cache {
-	clone := &RecurrentCache{
-		offset:     c.offset,
-		convTail:   c.convTail,
-		convDim:    c.convDim,
-		numVHeads:  c.numVHeads,
-		headVDim:   c.headVDim,
-		headKDim:   c.headKDim,
-		convState:  snapshotPinned(c.convState),
-		deltaState: snapshotPinned(c.deltaState),
-	}
-	return clone
-}
-
-func (c *RecurrentCache) Free() {
-	mlx.Unpin(c.convState, c.deltaState)
-	c.convState, c.deltaState = nil, nil
-	c.offset = 0
-}
-
-func (c *RecurrentCache) Offset() int { return c.offset }
-func (c *RecurrentCache) Len() int    { return c.offset }
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"math"
 	"math/rand"
 	"net"
 	"net/http"
@@ -18,27 +19,25 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
-	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/imagegen/manifest"
 )
 
 // Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
 type Client struct {
 	port          int
 	modelName     string
-	contextLength atomic.Int64
-	memory        atomic.Uint64
+	vramSize      uint64
 	done          chan error
 	client        *http.Client
 	lastErr       string
 	lastErrLock   sync.Mutex
 	mu            sync.Mutex
 	cmd           *exec.Cmd
 }
 
 // NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
@@ -99,9 +98,18 @@ func NewClient(modelName string) (*Client, error) {
 		slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
 	}
 
+	// Estimate VRAM based on tensor size from manifest
+	var vramSize uint64
+	if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
+		vramSize = uint64(modelManifest.TotalTensorSize())
+	} else {
+		vramSize = 8 * 1024 * 1024 * 1024
+	}
+
 	c := &Client{
 		port:      port,
 		modelName: modelName,
+		vramSize:  vramSize,
 		done:      make(chan error, 1),
 		client:    &http.Client{Timeout: 10 * time.Minute},
 		cmd:       cmd,
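The estimate above is a static budget computed once at startup rather than a live reading from the runner. A hedged sketch of the same logic as a standalone helper — the loadBudget name is hypothetical; manifest.LoadManifest and TotalTensorSize are the calls this change introduces:

// Sketch: derive a fixed VRAM budget from the model manifest, with an
// 8 GiB fallback when the manifest cannot be loaded (same constant as above).
func loadBudget(model string) uint64 {
	if m, err := manifest.LoadManifest(model); err == nil {
		return uint64(m.TotalTensorSize())
	}
	return 8 << 30 // 8 GiB
}

The trade-off: this is simpler and race-free, but it cannot observe runtime growth (for example KV-cache allocation) the way the removed /v1/status polling could.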
@@ -182,34 +190,15 @@ func (c *Client) waitUntilRunning() error {
 // completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
 type completionRequest struct {
 	Prompt  string          `json:"prompt"`
-	Think   *bool           `json:"think,omitempty"`
 	Options *completionOpts `json:"options,omitempty"`
 }
 
 type completionOpts struct {
-	Temperature      *float32 `json:"temperature,omitempty"`
-	TopP             *float32 `json:"top_p,omitempty"`
-	MinP             *float32 `json:"min_p,omitempty"`
-	TopK             *int     `json:"top_k,omitempty"`
-	RepeatLastN      *int     `json:"repeat_last_n,omitempty"`
-	RepeatPenalty    *float32 `json:"repeat_penalty,omitempty"`
-	PresencePenalty  *float32 `json:"presence_penalty,omitempty"`
-	FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
-	NumPredict       int      `json:"num_predict,omitempty"`
-}
-
-type CompletionResponse struct {
-	Content    string
-	Done       bool
-	DoneReason int
-
-	PromptEvalCount    int
-	PromptEvalDuration time.Duration
-	EvalCount          int
-	EvalDuration       time.Duration
-	PeakMemory         uint64
-
-	Error *api.StatusError
+	Temperature float32 `json:"temperature,omitempty"`
+	TopP        float32 `json:"top_p,omitempty"`
+	MinP        float32 `json:"min_p,omitempty"`
+	TopK        int     `json:"top_k,omitempty"`
+	NumPredict  int     `json:"num_predict,omitempty"`
 }
 
 // Close terminates the subprocess.
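One behavioral consequence of dropping the pointer fields above: with value fields and omitempty, an explicitly-set zero (for example temperature 0) is omitted from the JSON, so the runner falls back to its default, whereas a non-nil pointer to zero was still serialized. A minimal, self-contained illustration of the difference:

package main

import (
	"encoding/json"
	"fmt"
)

type ptrOpts struct {
	Temperature *float32 `json:"temperature,omitempty"`
}

type valOpts struct {
	Temperature float32 `json:"temperature,omitempty"`
}

func main() {
	zero := float32(0)
	p, _ := json.Marshal(ptrOpts{Temperature: &zero})
	v, _ := json.Marshal(valOpts{Temperature: 0})
	fmt.Println(string(p)) // {"temperature":0} — explicit zero survives via pointer
	fmt.Println(string(v)) // {} — zero dropped by omitempty on a value field
}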
@@ -233,27 +222,16 @@ func (c *Client) Close() error {
 
 // Completion implements llm.LlamaServer.
 func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
-	var think *bool
-	if req.Think != nil {
-		enabled := req.Think.Bool()
-		think = &enabled
-	}
-
 	creq := completionRequest{
 		Prompt: req.Prompt,
-		Think:  think,
 	}
 	if req.Options != nil {
 		creq.Options = &completionOpts{
-			Temperature:      float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")),
-			TopP:             float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")),
-			MinP:             float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")),
-			TopK:             intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")),
-			RepeatLastN:      intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")),
-			RepeatPenalty:    float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
-			PresencePenalty:  float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
-			FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
-			NumPredict:       req.Options.NumPredict,
+			Temperature: req.Options.Temperature,
+			TopP:        req.Options.TopP,
+			MinP:        req.Options.MinP,
+			TopK:        req.Options.TopK,
+			NumPredict:  req.Options.NumPredict,
 		}
 	}
 
@@ -282,25 +260,28 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 
 	scanner := bufio.NewScanner(resp.Body)
 	for scanner.Scan() {
-		var raw CompletionResponse
+		var raw struct {
+			Content            string `json:"content,omitempty"`
+			Done               bool   `json:"done"`
+			DoneReason         int    `json:"done_reason,omitempty"`
+			PromptEvalCount    int    `json:"prompt_eval_count,omitempty"`
+			PromptEvalDuration int    `json:"prompt_eval_duration,omitempty"`
+			EvalCount          int    `json:"eval_count,omitempty"`
+			EvalDuration       int    `json:"eval_duration,omitempty"`
+		}
 		if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
 			slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
 			continue
 		}
 
-		if raw.Error != nil {
-			return *raw.Error
-		}
-
 		cresp := llm.CompletionResponse{
 			Content:            raw.Content,
 			Done:               raw.Done,
 			DoneReason:         llm.DoneReason(raw.DoneReason),
 			PromptEvalCount:    raw.PromptEvalCount,
-			PromptEvalDuration: raw.PromptEvalDuration,
+			PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
 			EvalCount:          raw.EvalCount,
-			EvalDuration:       raw.EvalDuration,
-			PeakMemory:         raw.PeakMemory,
+			EvalDuration:       time.Duration(raw.EvalDuration),
 		}
 
 		fn(cresp)
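The loop above consumes the runner's newline-delimited JSON stream one object at a time. A minimal, self-contained version of the same pattern — the payload is illustrative, not the runner's actual wire format:

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// Sketch: decode an NDJSON stream line by line, as the Completion loop
// above does with the runner's response body.
func main() {
	stream := "{\"content\":\"Hel\",\"done\":false}\n{\"content\":\"lo\",\"done\":true}\n"
	scanner := bufio.NewScanner(strings.NewReader(stream))
	for scanner.Scan() {
		var msg struct {
			Content string `json:"content"`
			Done    bool   `json:"done"`
		}
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			continue // skip malformed lines, matching the client's behavior
		}
		fmt.Printf("chunk=%q done=%v\n", msg.Content, msg.Done)
	}
}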
@@ -312,27 +293,8 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return scanner.Err()
 }
 
-func hasExplicitOption(explicit map[string]struct{}, key string) bool {
-	_, ok := explicit[key]
-	return ok
-}
-
-func float32Ptr(v float32, ok bool) *float32 {
-	if !ok {
-		return nil
-	}
-	return &v
-}
-
-func intPtr(v int, ok bool) *int {
-	if !ok {
-		return nil
-	}
-	return &v
-}
-
 func (c *Client) ContextLength() int {
-	return int(c.contextLength.Load())
+	return math.MaxInt
 }
 
 // Detokenize implements llm.LlamaServer.
@@ -385,16 +347,9 @@ func (c *Client) Pid() int {
 	return -1
 }
 
-type statusResponse struct {
-	Status        int
-	Progress      int
-	ContextLength int
-	Memory        uint64
-}
-
 // Ping implements llm.LlamaServer.
 func (c *Client) Ping(ctx context.Context) error {
-	reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
+	reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
 	req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
 	if err != nil {
 		return err
@@ -407,15 +362,6 @@ func (c *Client) Ping(ctx context.Context) error {
 	if resp.StatusCode != http.StatusOK {
 		return fmt.Errorf("health check failed: %d", resp.StatusCode)
 	}
-
-	var status statusResponse
-	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
-		return err
-	}
-
-	c.contextLength.Store(int64(status.ContextLength))
-	c.memory.Store(status.Memory)
-
 	return nil
 }
 
@@ -442,24 +388,19 @@ func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
 	return tokens, nil
 }
 
-func (c *Client) currentMemory() uint64 {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
-	if err := c.Ping(ctx); err != nil {
-		slog.Warn("failed to get current memory", "error", err)
-	}
-	return c.memory.Load()
-}
-
-// MemorySize implements llm.LlamaServer.
-func (c *Client) MemorySize() (total, vram uint64) {
-	mem := c.currentMemory()
-	return mem, mem
+// TotalSize implements llm.LlamaServer.
+func (c *Client) TotalSize() uint64 {
+	return c.vramSize
 }
 
 // VRAMByGPU implements llm.LlamaServer.
 func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
-	return c.currentMemory()
+	return c.vramSize
+}
+
+// VRAMSize implements llm.LlamaServer.
+func (c *Client) VRAMSize() uint64 {
+	return c.vramSize
 }
 
 // WaitUntilRunning implements llm.LlamaServer.
@@ -1,167 +0,0 @@
-package mlxrunner
-
-import (
-	"context"
-	"encoding/json"
-	"io"
-	"net/http"
-	"strings"
-	"testing"
-
-	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
-)
-
-func TestCompletionForwardsThink(t *testing.T) {
-	boolPtr := func(v bool) *bool { return &v }
-
-	testCases := []struct {
-		name  string
-		think *api.ThinkValue
-		want  *bool
-	}{
-		{name: "unset", think: nil, want: nil},
-		{name: "enabled", think: &api.ThinkValue{Value: true}, want: boolPtr(true)},
-		{name: "disabled", think: &api.ThinkValue{Value: false}, want: boolPtr(false)},
-		{name: "level maps to enabled", think: &api.ThinkValue{Value: "high"}, want: boolPtr(true)},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			var got completionRequest
-
-			rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
-				if r.URL.Path != "/completion" {
-					t.Fatalf("request path = %q, want %q", r.URL.Path, "/completion")
-				}
-
-				if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
-					return nil, err
-				}
-
-				return &http.Response{
-					StatusCode: http.StatusOK,
-					Header:     make(http.Header),
-					Body:       io.NopCloser(strings.NewReader("{\"done\":true}\n")),
-					Request:    r,
-				}, nil
-			})
-
-			c := &Client{
-				port: 11434,
-				client: &http.Client{
-					Transport: rt,
-				},
-			}
-
-			err := c.Completion(context.Background(), llm.CompletionRequest{
-				Prompt: "hello",
-				Think:  tc.think,
-			}, func(llm.CompletionResponse) {})
-			if err != nil {
-				t.Fatalf("completion request failed: %v", err)
-			}
-
-			if got.Prompt != "hello" {
-				t.Fatalf("prompt = %q, want %q", got.Prompt, "hello")
-			}
-
-			switch {
-			case tc.want == nil && got.Think != nil:
-				t.Fatalf("think = %v, want nil", *got.Think)
-			case tc.want != nil && got.Think == nil:
-				t.Fatalf("think = nil, want %v", *tc.want)
-			case tc.want != nil && got.Think != nil && *tc.want != *got.Think:
-				t.Fatalf("think = %v, want %v", *got.Think, *tc.want)
-			}
-		})
-	}
-}
-
-func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
-	var got completionRequest
-
-	rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
-		if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
-			return nil, err
-		}
-
-		return &http.Response{
-			StatusCode: http.StatusOK,
-			Header:     make(http.Header),
-			Body:       io.NopCloser(strings.NewReader("{\"done\":true}\n")),
-			Request:    r,
-		}, nil
-	})
-
-	c := &Client{
-		port: 11434,
-		client: &http.Client{
-			Transport: rt,
-		},
-	}
-
-	opts := &api.Options{
-		Temperature:      1.0,
-		TopP:             0.95,
-		MinP:             0.1,
-		TopK:             20,
-		RepeatLastN:      128,
-		RepeatPenalty:    1.2,
-		PresencePenalty:  1.5,
-		FrequencyPenalty: 0.25,
-		NumPredict:       64,
-	}
-
-	err := c.Completion(context.Background(), llm.CompletionRequest{
-		Prompt:  "hello",
-		Options: opts,
-		ExplicitOptions: map[string]struct{}{
-			"temperature":      {},
-			"top_k":            {},
-			"repeat_penalty":   {},
-			"presence_penalty": {},
-		},
-	}, func(llm.CompletionResponse) {})
-	if err != nil {
-		t.Fatalf("completion request failed: %v", err)
-	}
-
-	if got.Options == nil {
-		t.Fatal("options = nil, want serialized options")
-	}
-
-	if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
-		t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
-	}
-	if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
-		t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
-	}
-	if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
-		t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
-	}
-	if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
-		t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
-	}
-	if got.Options.TopP != nil {
-		t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
-	}
-	if got.Options.MinP != nil {
-		t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
-	}
-	if got.Options.RepeatLastN != nil {
-		t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
-	}
-	if got.Options.FrequencyPenalty != nil {
-		t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
-	}
-	if got.Options.NumPredict != opts.NumPredict {
-		t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
-	}
-}
-
-type roundTripFunc func(*http.Request) (*http.Response, error)
-
-func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
-	return f(r)
-}
@@ -7,6 +7,4 @@ import (
 	_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
 	_ "github.com/ollama/ollama/x/models/llama"
 	_ "github.com/ollama/ollama/x/models/qwen3"
-	_ "github.com/ollama/ollama/x/models/qwen3_5"
-	_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
 )
@@ -1,275 +0,0 @@
-//go:build mlx
-
-package mlx
-
-// #include <stdlib.h>
-// #include "generated.h"
-import "C"
-
-import (
-	"sync"
-	"sync/atomic"
-	"unsafe"
-)
-
-var (
-	gatedDeltaMetalKernelOnce sync.Once
-	gatedDeltaMetalKernel     C.mlx_fast_metal_kernel
-	gatedDeltaMetalDisabled   atomic.Bool
-)
-
-const gatedDeltaMetalKernelSource = `
-	auto n = thread_position_in_grid.z;
-	auto b_idx = n / Hv;
-	auto hv_idx = n % Hv;
-	auto hk_idx = hv_idx / (Hv / Hk);
-	constexpr int n_per_t = Dk / 32;
-
-	// q, k: [B, T, Hk, Dk]
-	auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
-	auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
-
-	// v, y: [B, T, Hv, Dv]
-	auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
-	y += b_idx * T * Hv * Dv + hv_idx * Dv;
-
-	auto dk_idx = thread_position_in_threadgroup.x;
-	auto dv_idx = thread_position_in_grid.y;
-
-	// state_in, state_out: [B, Hv, Dv, Dk]
-	auto i_state = state_in + (n * Dv + dv_idx) * Dk;
-	auto o_state = state_out + (n * Dv + dv_idx) * Dk;
-
-	float state[n_per_t];
-	for (int i = 0; i < n_per_t; ++i) {
-		auto s_idx = n_per_t * dk_idx + i;
-		state[i] = static_cast<float>(i_state[s_idx]);
-	}
-
-	// g: [B, T, Hv]
-	auto g_ = g + b_idx * T * Hv;
-	auto beta_ = beta + b_idx * T * Hv;
-
-	for (int t = 0; t < T; ++t) {
-		float kv_mem = 0.0f;
-		for (int i = 0; i < n_per_t; ++i) {
-			auto s_idx = n_per_t * dk_idx + i;
-			state[i] = state[i] * g_[hv_idx];
-			kv_mem += state[i] * k_[s_idx];
-		}
-		kv_mem = simd_sum(kv_mem);
-
-		auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
-
-		float out = 0.0f;
-		for (int i = 0; i < n_per_t; ++i) {
-			auto s_idx = n_per_t * dk_idx + i;
-			state[i] = state[i] + k_[s_idx] * delta;
-			out += state[i] * q_[s_idx];
-		}
-		out = simd_sum(out);
-		if (thread_index_in_simdgroup == 0) {
-			y[dv_idx] = static_cast<InT>(out);
-		}
-
-		q_ += Hk * Dk;
-		k_ += Hk * Dk;
-		v_ += Hv * Dv;
-		y += Hv * Dv;
-		g_ += Hv;
-		beta_ += Hv;
-	}
-
-	for (int i = 0; i < n_per_t; ++i) {
-		auto s_idx = n_per_t * dk_idx + i;
-		o_state[s_idx] = static_cast<InT>(state[i]);
-	}
-`
-
-func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
-	vec := C.mlx_vector_string_new()
-	ok := true
-	for _, s := range values {
-		cs := C.CString(s)
-		if C.mlx_vector_string_append_value(vec, cs) != 0 {
-			ok = false
-		}
-		C.free(unsafe.Pointer(cs))
-		if !ok {
-			break
-		}
-	}
-	cleanup := func() {
-		C.mlx_vector_string_free(vec)
-	}
-	return vec, cleanup, ok
-}
-
-func initGatedDeltaMetalKernel() {
-	inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
-	if !ok {
-		gatedDeltaMetalDisabled.Store(true)
-		freeInputs()
-		return
-	}
-	defer freeInputs()
-
-	outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
-	if !ok {
-		gatedDeltaMetalDisabled.Store(true)
-		freeOutputs()
-		return
-	}
-	defer freeOutputs()
-
-	cName := C.CString("gated_delta_step")
-	defer C.free(unsafe.Pointer(cName))
-	cSource := C.CString(gatedDeltaMetalKernelSource)
-	defer C.free(unsafe.Pointer(cSource))
-	cHeader := C.CString("")
-	defer C.free(unsafe.Pointer(cHeader))
-
-	gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
-		cName,
-		inputs,
-		outputs,
-		cSource,
-		cHeader,
-		C.bool(true),
-		C.bool(false),
-	)
-}
-
-// GatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
-// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
-func GatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
-	if gatedDeltaMetalDisabled.Load() {
-		return nil, nil, false
-	}
-	if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
-		return nil, nil, false
-	}
-	if !q.Valid() || !k.Valid() || !v.Valid() || !g.Valid() || !beta.Valid() || !state.Valid() {
-		return nil, nil, false
-	}
-
-	qd := q.Dims()
-	kd := k.Dims()
-	vd := v.Dims()
-	gd := g.Dims()
-	bd := beta.Dims()
-	sd := state.Dims()
-	if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
-		return nil, nil, false
-	}
-
-	B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
-	if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
-		return nil, nil, false
-	}
-	if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
-		return nil, nil, false
-	}
-	Hv, Dv := vd[2], vd[3]
-	if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
-		return nil, nil, false
-	}
-	if gd[0] != B || gd[1] != T || gd[2] != Hv {
-		return nil, nil, false
-	}
-	if bd[0] != B || bd[1] != T || bd[2] != Hv {
-		return nil, nil, false
-	}
-	if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
-		return nil, nil, false
-	}
-
-	dtype := q.DType()
-	if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
-		return nil, nil, false
-	}
-
-	gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
-	if gatedDeltaMetalDisabled.Load() {
-		return nil, nil, false
-	}
-
-	cfg := C.mlx_fast_metal_kernel_config_new()
-	defer C.mlx_fast_metal_kernel_config_free(cfg)
-
-	cInT := C.CString("InT")
-	defer C.free(unsafe.Pointer(cInT))
-	if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-	for _, tpl := range []struct {
-		name  string
-		value int
-	}{
-		{name: "Dk", value: Dk},
-		{name: "Dv", value: Dv},
-		{name: "Hk", value: Hk},
-		{name: "Hv", value: Hv},
-	} {
-		cn := C.CString(tpl.name)
-		rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
-		C.free(unsafe.Pointer(cn))
-		if rc != 0 {
-			gatedDeltaMetalDisabled.Store(true)
-			return nil, nil, false
-		}
-	}
-
-	yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
-	stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
-	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-	if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-	if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-	threadY := Dv
-	if threadY > 4 {
-		threadY = 4
-	}
-	if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-
-	tScalar := FromValue(T)
-	inputs := []C.mlx_array{
-		q.ctx,
-		k.ctx,
-		v.ctx,
-		g.ctx,
-		beta.ctx,
-		state.ctx,
-		tScalar.ctx,
-	}
-	inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
-	defer C.mlx_vector_array_free(inVec)
-
-	outVec := C.mlx_vector_array_new()
-	defer C.mlx_vector_array_free(outVec)
-	if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
-		gatedDeltaMetalDisabled.Store(true)
-		return nil, nil, false
-	}
-	if int(C.mlx_vector_array_size(outVec)) < 2 {
-		return nil, nil, false
-	}
-
-	y = New("GATED_DELTA_METAL_Y")
-	nextState = New("GATED_DELTA_METAL_STATE")
-	C.mlx_vector_array_get(&y.ctx, outVec, 0)
-	C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
-	return y, nextState, true
-}
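The removed kernel fuses a gated delta-rule recurrence per timestep: decay the state by g, read what the state already stores for key k, write back the beta-gated residual, then read the output with q. For reference, a naive single-head, single-timestep sketch of the same update in Go — purely illustrative, following the shapes in the kernel's comments, with S the [Dv][Dk] state matrix:

// Naive reference for one head of the gated delta step the Metal kernel fused.
// q and k have length Dk, v has length Dv, g and beta are per-head scalars.
func gatedDeltaStep(S [][]float32, q, k, v []float32, g, beta float32) []float32 {
	Dv, Dk := len(S), len(k)
	y := make([]float32, Dv)
	for dv := 0; dv < Dv; dv++ {
		// Decay the state, then read what it already stores for key k.
		var kvMem float32
		for dk := 0; dk < Dk; dk++ {
			S[dv][dk] *= g
			kvMem += S[dv][dk] * k[dk]
		}
		// Write the residual (delta rule), gated by beta.
		delta := (v[dv] - kvMem) * beta
		for dk := 0; dk < Dk; dk++ {
			S[dv][dk] += k[dk] * delta
			y[dv] += S[dv][dk] * q[dk]
		}
	}
	return y
}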
@@ -64,10 +64,6 @@ func PeakMemory() int {
 	return int(peak)
 }
 
-func ResetPeakMemory() {
-	C.mlx_reset_peak_memory()
-}
-
 type Memory struct{}
 
 func (Memory) LogValue() slog.Value {
@@ -19,7 +19,7 @@ func doEval(outputs []*Array, async bool) {
 	defer C.mlx_vector_array_free(vector)
 
 	for _, output := range outputs {
-		if output != nil && output.Valid() {
+		if output.Valid() {
 			C.mlx_vector_array_append_value(vector, output.ctx)
 		}
 	}
@@ -93,12 +93,6 @@ func (t *Array) Divide(other *Array) *Array {
 	return out
 }
 
-func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
-	out := New("CUMSUM")
-	C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
-	return out
-}
-
 func (t *Array) ExpandDims(axis int) *Array {
 	out := New("EXPAND_DIMS")
 	C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
@@ -129,30 +123,12 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
 	return out
 }
 
-func (t *Array) GreaterEqual(other *Array) *Array {
-	out := New("GREATER_EQUAL")
-	C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
-	return out
-}
-
 func (t *Array) Logsumexp(keepDims bool) *Array {
 	out := New("LOGSUMEXP")
 	C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
-func (t *Array) Less(other *Array) *Array {
-	out := New("LESS")
-	C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
-	return out
-}
-
-func (t *Array) LogicalOr(other *Array) *Array {
-	out := New("LOGICAL_OR")
-	C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
-	return out
-}
-
 func (t *Array) Matmul(other *Array) *Array {
 	out := New("MATMUL")
 	C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
@@ -113,35 +113,6 @@ func Where(condition, a, b *Array) *Array {
 	return out
 }
 
-func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
-	out := New("CONV1D")
-	C.mlx_conv1d(
-		&out.ctx,
-		x.ctx,
-		weight.ctx,
-		C.int(stride),
-		C.int(padding),
-		C.int(dilation),
-		C.int(groups),
-		DefaultStream().ctx,
-	)
-	if bias != nil && bias.Valid() {
-		out = Add(out, bias)
-	}
-	return out
-}
-
-func Contiguous(a *Array, allowColMajor bool) *Array {
-	out := New("CONTIGUOUS")
-	C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
-	return out
-}
-
-func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
-	groups := int32(x.Dim(x.NumDims() - 1))
-	return Conv1d(x, weight, bias, 1, 0, 1, groups)
-}
-
 // Convenience wrappers (function-style for the model code)
 
 func Stack(arrays []*Array, axis int) *Array {
@@ -300,24 +271,6 @@ func Sigmoid(a *Array) *Array {
 	return a.Sigmoid()
 }
 
-func Exp(a *Array) *Array {
-	out := New("EXP")
-	C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
-	return out
-}
-
-func Log(a *Array) *Array {
-	out := New("LOG")
-	C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
-	return out
-}
-
-func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
-	out := New("SOFTMAX_AXIS")
-	C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
-	return out
-}
-
 func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
 	mask := New("")
 	sinks := New("")
@@ -335,11 +288,7 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
 
 func RMSNormFn(x, weight *Array, eps float32) *Array {
 	out := New("FAST_RMSNORM")
-	var w C.mlx_array
-	if weight != nil {
-		w = weight.ctx
-	}
-	C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
+	C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
 	return out
 }
 
@@ -429,27 +378,6 @@ func Collect(v any) []*Array {
 	return arrays
 }
 
-// Snapshot copies an array into a fresh leaf value with no Go-side graph inputs.
-func Snapshot(a *Array) *Array {
-	if a == nil || !a.Valid() {
-		return a
-	}
-	out := New("SNAPSHOT")
-	C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
-	return out
-}
-
-// Detach returns a new Array handle that shares the same MLX value but does
-// not retain Go-side graph input references.
-func Detach(a *Array) *Array {
-	if a == nil || !a.Valid() {
-		return a
-	}
-	out := New("DETACH")
-	C.mlx_array_set(&out.ctx, a.ctx)
-	return out
-}
-
 func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
 	if !v.IsValid() {
 		return
@@ -20,7 +20,6 @@ type Model interface {
 	Unembed(x *mlx.Array) *mlx.Array
 	NumLayers() int
 	Tokenizer() *tokenizer.Tokenizer
-	MaxContextLength() int
 
 	// LoadWeights receives all tensors loaded from the manifest and assigns
 	// them to model fields. Model-specific logic (MLA absorption, expert
@@ -6,30 +6,18 @@ import (
 	"bytes"
 	"context"
 	"errors"
-	"fmt"
 	"log/slog"
-	"net/http"
 	"time"
 
-	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/x/mlxrunner/mlx"
 )
 
-func prefillChunkSize() int {
-	return 2 << 10
-}
-
 func (r *Runner) TextGenerationPipeline(request Request) error {
 	if r.Model == nil {
 		return errors.New("model not loaded")
 	}
 
-	ctx := request.Ctx
-	if ctx == nil {
-		ctx = context.Background()
-	}
-
 	var (
 		sample, logprobs         *mlx.Array
 		nextSample, nextLogprobs *mlx.Array
@@ -56,72 +44,43 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
     } else {
         mlx.DisableCompile()
     }
-    mlx.ResetPeakMemory()

     inputs := r.Tokenizer.Encode(request.Prompt, true)
-    if len(inputs) == 0 {
-        return errors.New("empty prompt")
-    }
-
-    if len(inputs) >= r.contextLength {
-        return api.StatusError{
-            StatusCode:   http.StatusBadRequest,
-            ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
-        }
-    }
-
-    // Cap generation to stay within the model's context length
-    maxGenerate := r.contextLength - len(inputs)
-    if request.Options.MaxTokens <= 0 {
-        request.Options.MaxTokens = maxGenerate
-    } else {
-        request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
-    }
-
     session := r.cache.begin(r.Model, inputs)
     defer session.close()

     caches := session.caches
     tokens := session.remaining
-    history := append([]int32(nil), session.inputs...)
-    prefillChunk := prefillChunkSize()
-
-    materializeCaches := func() {
-        state := make([]*mlx.Array, 0, 2*len(caches))
-        for _, c := range caches {
-            if c == nil {
-                continue
-            }
-            state = append(state, c.Materialize()...)
-        }
-        if len(state) == 0 {
-            return
-        }
-        mlx.Eval(state...)
-    }
-
-    now := time.Now()
     total, processed := len(tokens), 0
+    slog.Info("Prompt processing progress", "processed", processed, "total", total)
     for total-processed > 1 {
-        if err := ctx.Err(); err != nil {
+        if err := request.Ctx.Err(); err != nil {
             return err
         }

-        n := min(prefillChunk, total-processed-1)
+        n := min(2<<10, total-processed-1)
         r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
         mlx.Sweep()
-        materializeCaches()
+        mlx.Eval(func() []*mlx.Array {
+            s := make([]*mlx.Array, 2*len(caches))
+            for i, c := range caches {
+                s[2*i], s[2*i+1] = c.State()
+            }
+            return s
+        }()...)
         processed += n
         slog.Info("Prompt processing progress", "processed", processed, "total", total)
         mlx.ClearCache()
     }

-    step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) {
+    step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
         fwd := r.Model.Forward(token.ExpandDims(0), caches)
         logits := r.Model.Unembed(fwd)
         logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

         logprobs := logits.Subtract(logits.Logsumexp(true))
-        sample := request.Sample(logprobs, history)
+        sample := request.Sample(logprobs)

         mlx.Pin(sample, logprobs)
         mlx.Sweep()
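Both sides of this hunk prefill the prompt in fixed-size chunks (2<<10 = 2048 tokens), deliberately leaving the final prompt token unprocessed so it can seed the first decode step. A standalone sketch of just that chunking arithmetic, independent of MLX (names are illustrative):

package main

import "fmt"

// prefillChunks reproduces the loop bounds above: all but the final prompt
// token are processed in chunks of at most chunkSize; the remaining token is
// fed to the first decode step so its logits can be sampled.
func prefillChunks(total, chunkSize int) [][2]int {
    var spans [][2]int
    processed := 0
    for total-processed > 1 {
        n := min(chunkSize, total-processed-1)
        spans = append(spans, [2]int{processed, processed + n})
        processed += n
    }
    return spans
}

func main() {
    fmt.Println(prefillChunks(5000, 2<<10)) // [[0 2048] [2048 4096] [4096 4999]]
}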
@@ -130,42 +89,45 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
         return sample, logprobs
     }

-    sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history)
+    sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))

     var b bytes.Buffer

-    final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
+    now := time.Now()
+    final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
     for i := range request.Options.MaxTokens {
-        if err := ctx.Err(); err != nil {
+        if err := request.Ctx.Err(); err != nil {
             return err
         }

+        nextSample, nextLogprobs = step(sample)
+
         if i == 0 {
+            slog.Info("Prompt processing progress", "processed", total, "total", total)
             mlx.Eval(sample)
-            final.PromptEvalDuration = time.Since(now)
+            final.PromptTokensDuration = time.Since(now)
             now = time.Now()
         }

         output := int32(sample.Int())
         session.outputs = append(session.outputs, output)
-        history = append(history, output)

         if r.Tokenizer.IsEOS(output) {
+            final.Token = int(output)
             final.DoneReason = 0
-            final.EvalCount = i
+            final.CompletionTokens = i
             break
         }

         select {
         case <-request.Ctx.Done():
             return request.Ctx.Err()
-        case request.Responses <- CompletionResponse{
-            Content: r.Decode(output, &b),
+        case request.Responses <- Response{
+            Text:  r.Decode(output, &b),
+            Token: int(output),
         }:
         }

-        nextSample, nextLogprobs = step(sample, history)
-
         mlx.Unpin(sample, logprobs)
         sample, logprobs = nextSample, nextLogprobs
         nextSample, nextLogprobs = nil, nil
@@ -175,11 +137,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
         }
     }

-    final.EvalDuration = time.Since(now)
-    final.PeakMemory = uint64(mlx.PeakMemory())
+    final.CompletionTokensDuration = time.Since(now)
     select {
-    case <-ctx.Done():
-        return ctx.Err()
+    case <-request.Ctx.Done():
+        return request.Ctx.Err()
     case request.Responses <- final:
         return nil
     }
 }
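One behavioral detail of this hunk: the new side moves the step(sample) call for token i+1 to the top of the loop, before sample.Int() synchronously reads token i, while the old side queued it only after the token had been streamed. Since MLX builds graphs lazily, this lets graph construction for the next token overlap the blocking read of the current one. A toy illustration of the pattern (buildGraph and materialize are hypothetical stand-ins, not MLX API):

package main

import "fmt"

// buildGraph stands in for the lazy step() call: it queues work but forces nothing.
func buildGraph(x int) int { return x + 1 }

// materialize stands in for the sync point, like sample.Int() above.
func materialize(x int) int { return x }

func main() {
    cur := 0
    for i := 0; i < 4; i++ {
        next := buildGraph(cur) // queued before the synchronous read below
        fmt.Println(materialize(cur))
        cur = next
    }
}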
@@ -4,15 +4,14 @@ package mlxrunner

 import (
     "context"
-    "errors"
     "log/slog"
     "net"
     "net/http"
     "strings"
+    "time"

     "golang.org/x/sync/errgroup"

-    "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/x/mlxrunner/mlx"
     "github.com/ollama/ollama/x/mlxrunner/model"
     "github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -22,7 +21,7 @@ import (

 type Request struct {
     TextCompletionsRequest
-    Responses chan CompletionResponse
+    Responses chan Response
     Pipeline  func(Request) error

     Ctx context.Context
@@ -32,29 +31,37 @@ type Request struct {

 type TextCompletionsRequest struct {
     Prompt string `json:"prompt"`
-    Think  *bool  `json:"think,omitempty"`
     Options struct {
-        Temperature      *float32 `json:"temperature"`
-        TopP             *float32 `json:"top_p"`
-        MinP             *float32 `json:"min_p"`
-        TopK             *int     `json:"top_k"`
-        RepeatLastN      *int     `json:"repeat_last_n"`
-        RepeatPenalty    *float32 `json:"repeat_penalty"`
-        PresencePenalty  *float32 `json:"presence_penalty"`
-        FrequencyPenalty *float32 `json:"frequency_penalty"`
-        MaxTokens        int      `json:"max_tokens"`
+        Temperature float32 `json:"temperature"`
+        TopP        float32 `json:"top_p"`
+        MinP        float32 `json:"min_p"`
+        TopK        int     `json:"top_k"`
+        MaxTokens   int     `json:"max_tokens"`

         // Deprecated: use MaxTokens instead
         NumPredict int `json:"num_predict"`
     } `json:"options"`
 }

+type Response struct {
+    Text       string    `json:"content,omitempty"`
+    Token      int       `json:"token,omitempty"`
+    Logprobs   []float32 `json:"logprobs,omitempty"`
+    Done       bool      `json:"done,omitempty"`
+    DoneReason int       `json:"done_reason,omitempty"`
+
+    PromptTokens             int           `json:"prompt_eval_count,omitempty"`
+    PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
+    CompletionTokens         int           `json:"eval_count,omitempty"`
+    CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
+    TotalTokens              int           `json:"total_tokens,omitempty"`
+}
+
 type Runner struct {
     Model     base.Model
     Tokenizer *tokenizer.Tokenizer
     Requests  chan Request
     cache     kvCache
-    contextLength int
 }

 func (r *Runner) Load(modelName string) error {
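Given the json tags on the new Response type, the wire format splits into streaming chunks and one final summary; encoding/json renders time.Duration fields as integer nanoseconds, and omitempty suppresses zero fields. Roughly (values illustrative):

// streaming chunk:
{"content":" world","token":1917}
// final summary (durations in nanoseconds):
{"done":true,"prompt_eval_count":42,"prompt_eval_duration":153000000,"eval_count":128,"eval_duration":2410000000}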
@@ -83,7 +90,6 @@ func (r *Runner) Load(modelName string) error {

     r.Model = m
     r.Tokenizer = m.Tokenizer()
-    r.contextLength = m.MaxContextLength()
     return nil
 }

@@ -152,17 +158,6 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
         case request := <-r.Requests:
             if err := request.Pipeline(request); err != nil {
                 slog.Info("Request terminated", "error", err)
-                var statusErr api.StatusError
-                if !errors.As(err, &statusErr) {
-                    statusErr = api.StatusError{
-                        StatusCode:   http.StatusInternalServerError,
-                        ErrorMessage: err.Error(),
-                    }
-                }
-                select {
-                case request.Responses <- CompletionResponse{Error: &statusErr}:
-                case <-request.Ctx.Done():
-                }
             }

             close(request.Responses)
@@ -9,204 +9,69 @@ import (
 )

 type Sampler interface {
-    Sample(*mlx.Array, []int32) *mlx.Array
+    Sample(*mlx.Array) *mlx.Array
 }

-func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler {
-    var samplers []Sampler
-    if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) {
-        samplers = append(samplers, Penalty{
-            RepeatLastN:      repeatLastN,
-            RepeatPenalty:    repeatPenalty,
-            PresencePenalty:  presencePenalty,
-            FrequencyPenalty: frequencyPenalty,
-        })
-    }
-
-    if temp == 0 {
-        samplers = append(samplers, greedy{})
-    } else {
-        samplers = append(samplers, Distribution{
-            Temperature: temp,
-            TopK:        top_k,
-            TopP:        top_p,
-            MinP:        min_p,
-        })
+func New(temp, top_p, min_p float32, top_k int) Sampler {
+    if temp == 0 {
+        return greedy{}
     }

+    var samplers []Sampler
+    if top_p > 0 && top_p < 1 {
+        samplers = append(samplers, TopP(top_p))
+    }
+
+    if min_p != 0 {
+        samplers = append(samplers, MinP(min_p))
+    }
+
+    if top_k > 0 {
+        samplers = append(samplers, TopK(top_k))
+    }
+
+    samplers = append(samplers, Temperature(temp))
     return chain(samplers)
 }

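On the new side, New composes the surviving filters into a chain applied left to right, each stage rewriting the logits before the next sees them. A toy re-creation of the same pattern with plain slices instead of *mlx.Array (the type names mirror the diff; everything else is illustrative):

package main

import "fmt"

type sampler interface{ sample([]float32) []float32 }

// chain applies each stage in order, exactly like chain.Sample above.
type chain []sampler

func (c chain) sample(logits []float32) []float32 {
    for _, s := range c {
        logits = s.sample(logits)
    }
    return logits
}

// temperature divides every logit by t, the scalar analogue of
// mlx.DivScalar(logits, float32(t)).
type temperature float32

func (t temperature) sample(logits []float32) []float32 {
    out := make([]float32, len(logits))
    for i, v := range logits {
        out[i] = v / float32(t)
    }
    return out
}

func main() {
    c := chain{temperature(0.5)}
    fmt.Println(c.sample([]float32{1, 2})) // [2 4]
}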
 type greedy struct{}

-func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
+func (greedy) Sample(logits *mlx.Array) *mlx.Array {
     return logits.Argmax(-1, false)
 }

 type chain []Sampler

-func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array {
+func (c chain) Sample(logits *mlx.Array) *mlx.Array {
     for _, sampler := range c {
-        logits = sampler.Sample(logits, history)
+        logits = sampler.Sample(logits)
     }
     return logits
 }

-type Distribution struct {
-    Temperature float32
-    TopK        int
-    TopP        float32
-    MinP        float32
-}
-
-func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
-    filtered, indices := d.filter(logits)
-    sample := filtered.Categorical(-1)
-    if indices == nil {
-        return sample
-    }
-
-    positions := sample.ExpandDims(1)
-    return indices.TakeAlongAxis(positions, -1).Squeeze(1)
-}
-
-func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) {
-    candidates := logits
-    var candidateIndices *mlx.Array
-
-    if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) {
-        partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1)
-        switch logits.NumDims() {
-        case 1:
-            candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
-        default:
-            candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
-        }
-        candidates = logits.TakeAlongAxis(candidateIndices, -1)
-    }
-
-    if d.Temperature != 1 {
-        candidates = mlx.DivScalar(candidates, d.Temperature)
-    }
-
-    if !d.needsProbabilityFilters() {
-        return candidates, candidateIndices
-    }
-
-    order := candidates.Negative().ArgsortAxis(-1)
-    sortedLogits := candidates.TakeAlongAxis(order, -1)
-    sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
-
-    remove := d.topPRemovalMask(sortedProbs)
-    if d.MinP > 0 {
-        minPRemove := d.minPRemovalMask(sortedProbs)
-        if remove == nil {
-            remove = minPRemove
-        } else {
-            remove = remove.LogicalOr(minPRemove)
-        }
-    }
-
-    if remove == nil {
-        return candidates, candidateIndices
-    }
-
-    negInf := mlx.FromValue(float32(math.Inf(-1)))
-    filtered := mlx.Where(remove, negInf, sortedLogits)
-    return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
-}
-
-func (d Distribution) needsProbabilityFilters() bool {
-    return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
-}
-
-func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
-    if d.TopP <= 0 || d.TopP >= 1 {
-        return nil
-    }
-
-    threshold := mlx.NewScalarArray(d.TopP)
-    prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
-    return prevCum.GreaterEqual(threshold)
-}
-
-func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
-    if d.MinP <= 0 {
-        return nil
-    }
-
-    var maxProb *mlx.Array
-    switch sortedProbs.NumDims() {
-    case 1:
-        maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
-    default:
-        maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
-    }
-
-    threshold := mlx.MulScalar(maxProb, d.MinP)
-    return sortedProbs.Less(threshold)
-}
-
-type Penalty struct {
-    RepeatLastN      int
-    RepeatPenalty    float32
-    PresencePenalty  float32
-    FrequencyPenalty float32
-}
-
-func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
-    if len(history) == 0 {
-        return logprobs
-    }
-
-    window := p.RepeatLastN
-    if window <= 0 || window > len(history) {
-        window = len(history)
-    }
-
-    counts := make(map[int32]int, window)
-    order := make([]int32, 0, window)
-    for _, token := range history[len(history)-window:] {
-        if token < 0 {
-            continue
-        }
-        if counts[token] == 0 {
-            order = append(order, token)
-        }
-        counts[token]++
-    }
-    if len(order) == 0 {
-        return logprobs
-    }
-
-    indexShape := []int32{int32(len(order))}
-    valueShape := []int{len(order)}
-    if logprobs.NumDims() > 1 {
-        indexShape = []int32{1, int32(len(order))}
-        valueShape = []int{1, len(order)}
-    }
-
-    indices := mlx.NewArrayInt32(order, indexShape)
-    selected := logprobs.TakeAlongAxis(indices, -1)
-    mlx.Eval(selected)
-
-    values := selected.Floats()
-    for i, token := range order {
-        v := values[i]
-        if p.RepeatPenalty != 1 {
-            if v < 0 {
-                v *= p.RepeatPenalty
-            } else {
-                v /= p.RepeatPenalty
-            }
-        }
-        if p.PresencePenalty != 0 {
-            v -= p.PresencePenalty
-        }
-        if p.FrequencyPenalty != 0 {
-            v -= p.FrequencyPenalty * float32(counts[token])
-        }
-        values[i] = v
-    }
-
-    return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
+type Temperature float32
+
+func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
+    return mlx.DivScalar(logits, float32(t)).Categorical(-1)
+}
+
+type TopP float32
+
+func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
+    // TODO: implement
+    return logprobs
+}
+
+type MinP float32
+
+func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
+    // TODO: implement
+    return logprobs
+}
+
+type TopK int
+
+func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
+    mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
+    return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
 }
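The removed filter built its keep/drop masks over probabilities sorted in descending order: top-p drops every token whose preceding cumulative mass already reaches TopP (that is what the Cumsum-minus-self "prevCum" computes), and min-p drops tokens whose probability falls below MinP times the top probability. A plain-Go sketch of that masking over a single pre-sorted row (assumes the input sums to 1; names are illustrative):

package main

import "fmt"

// removalMask mirrors topPRemovalMask/minPRemovalMask from the removed code:
// a token is removed when the cumulative probability *before* it already
// reaches topP, or when it is below minP times the largest probability.
func removalMask(sortedProbs []float32, topP, minP float32) []bool {
    remove := make([]bool, len(sortedProbs))
    var prevCum float32
    for i, p := range sortedProbs {
        if topP > 0 && topP < 1 && prevCum >= topP {
            remove[i] = true
        }
        if minP > 0 && p < minP*sortedProbs[0] {
            remove[i] = true
        }
        prevCum += p
    }
    return remove
}

func main() {
    // With topP = 0.55 the first token (0.7) already covers the nucleus,
    // so everything after it is masked -- the same shape of result that
    // TestDistributionFilterTopP below asserts on MLX arrays.
    fmt.Println(removalMask([]float32{0.7, 0.2, 0.07, 0.03}, 0.55, 0))
    // [false true true true]
}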
@@ -1,104 +0,0 @@
-//go:build mlx
-
-package sample
-
-import (
-    "math"
-    "testing"
-
-    "github.com/ollama/ollama/x/mlxrunner/mlx"
-)
-
-func TestPenaltySample(t *testing.T) {
-    if err := mlx.CheckInit(); err != nil {
-        t.Skipf("MLX not available: %v", err)
-    }
-
-    logprobs := mlx.FromValues([]float32{
-        1.0, -2.0, 3.0, 4.0,
-    }, 1, 4)
-
-    got := Penalty{
-        RepeatLastN:      3,
-        RepeatPenalty:    2.0,
-        PresencePenalty:  1.5,
-        FrequencyPenalty: 0.25,
-    }.Sample(logprobs, []int32{2, 1, 2})
-
-    mlx.Eval(got)
-
-    want := []float32{1.0, -5.75, -0.5, 4.0}
-    values := got.Floats()
-    if len(values) != len(want) {
-        t.Fatalf("len(values) = %d, want %d", len(values), len(want))
-    }
-
-    for i := range want {
-        if math.Abs(float64(values[i]-want[i])) > 1e-5 {
-            t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
-        }
-    }
-}
-
-func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
-    if err := mlx.CheckInit(); err != nil {
-        t.Skipf("MLX not available: %v", err)
-    }
-
-    logprobs := mlx.FromValues([]float32{
-        1.0, 2.0, 3.0,
-    }, 1, 3)
-
-    got := Penalty{
-        RepeatLastN:     1,
-        PresencePenalty: 1.0,
-    }.Sample(logprobs, []int32{0, 1})
-
-    mlx.Eval(got)
-
-    want := []float32{1.0, 1.0, 3.0}
-    values := got.Floats()
-    for i := range want {
-        if math.Abs(float64(values[i]-want[i])) > 1e-5 {
-            t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
-        }
-    }
-}
-
-func TestDistributionFilterTopP(t *testing.T) {
-    if err := mlx.CheckInit(); err != nil {
-        t.Skipf("MLX not available: %v", err)
-    }
-
-    logits := mlx.FromValues([]float32{
-        10.0, 9.0, 1.0, 0.0,
-    }, 1, 4)
-
-    filtered, indices := Distribution{
-        Temperature: 1.0,
-        TopK:        2,
-        TopP:        0.55,
-    }.filter(logits)
-
-    got := materializeFilteredLogits(filtered, indices, 4)
-    mlx.Eval(got)
-
-    values := got.Floats()
-    if values[0] != 10.0 {
-        t.Fatalf("values[0] = %v, want 10", values[0])
-    }
-    for i := 1; i < len(values); i++ {
-        if !math.IsInf(float64(values[i]), -1) {
-            t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
-        }
-    }
-}
-
-func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
-    if indices == nil {
-        return filtered
-    }
-
-    base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
-    return base.PutAlongAxis(indices, filtered, -1)
-}
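The expected values in the removed TestPenaltySample follow directly from Penalty.Sample's arithmetic over the history {2, 1, 2}, whose window of 3 yields counts of 2 for token 2 and 1 for token 1:

    token 1: -2.0 * 2.0 (repeat, negative logprob) = -4.0;  -4.0 - 1.5 (presence) = -5.5;  -5.5 - 0.25*1 (frequency) = -5.75
    token 2:  3.0 / 2.0 (repeat, positive logprob) =  1.5;   1.5 - 1.5 (presence) =  0.0;   0.0 - 0.25*2 (frequency) = -0.5

Tokens 0 and 3 never appear in the window, so their logprobs 1.0 and 4.0 pass through unchanged, giving want = {1.0, -5.75, -0.5, 4.0}.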
@@ -16,89 +16,12 @@ import (
     "strconv"
     "time"

-    "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/logutil"
     "github.com/ollama/ollama/x/mlxrunner/mlx"
-    "github.com/ollama/ollama/x/mlxrunner/model/base"
     "github.com/ollama/ollama/x/mlxrunner/sample"
-    "github.com/ollama/ollama/x/models/qwen3_5"
 )

-type samplingConfig struct {
-    temperature      float32
-    topP             float32
-    minP             float32
-    topK             int
-    repeatLastN      int
-    repeatPenalty    float32
-    presencePenalty  float32
-    frequencyPenalty float32
-}
-
-func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
-    if _, ok := m.(*qwen3_5.Model); ok {
-        cfg := samplingConfig{
-            temperature:      1.0,
-            topP:             0.95,
-            minP:             0.0,
-            topK:             20,
-            repeatLastN:      64,
-            repeatPenalty:    1.0,
-            presencePenalty:  1.5,
-            frequencyPenalty: 0.0,
-        }
-        if think != nil && !*think {
-            cfg.temperature = 0.7
-            cfg.topP = 0.8
-        }
-        return cfg
-    }
-
-    opts := api.DefaultOptions()
-    return samplingConfig{
-        temperature:      opts.Temperature,
-        topP:             opts.TopP,
-        minP:             opts.MinP,
-        topK:             opts.TopK,
-        repeatLastN:      opts.RepeatLastN,
-        repeatPenalty:    opts.RepeatPenalty,
-        presencePenalty:  opts.PresencePenalty,
-        frequencyPenalty: opts.FrequencyPenalty,
-    }
-}
-
-func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
-    cfg := defaultSamplingConfig(m, req.Think)
-
-    if req.Options.Temperature != nil {
-        cfg.temperature = *req.Options.Temperature
-    }
-    if req.Options.TopP != nil {
-        cfg.topP = *req.Options.TopP
-    }
-    if req.Options.MinP != nil {
-        cfg.minP = *req.Options.MinP
-    }
-    if req.Options.TopK != nil {
-        cfg.topK = *req.Options.TopK
-    }
-    if req.Options.RepeatLastN != nil {
-        cfg.repeatLastN = *req.Options.RepeatLastN
-    }
-    if req.Options.RepeatPenalty != nil {
-        cfg.repeatPenalty = *req.Options.RepeatPenalty
-    }
-    if req.Options.PresencePenalty != nil {
-        cfg.presencePenalty = *req.Options.PresencePenalty
-    }
-    if req.Options.FrequencyPenalty != nil {
-        cfg.frequencyPenalty = *req.Options.FrequencyPenalty
-    }
-
-    return cfg
-}
-
 func Execute(args []string) error {
     slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

@@ -127,11 +50,9 @@ func Execute(args []string) error {

     mux := http.NewServeMux()
     mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
-        if err := json.NewEncoder(w).Encode(statusResponse{
-            Status:        0,
-            Progress:      100,
-            ContextLength: runner.contextLength,
-            Memory:        uint64(mlx.ActiveMemory() + mlx.CacheMemory()),
+        if err := json.NewEncoder(w).Encode(map[string]any{
+            "status":   0,
+            "progress": 100,
         }); err != nil {
             slog.Error("Failed to encode response", "error", err)
             http.Error(w, "Internal Server Error", http.StatusInternalServerError)
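After this change the status endpoint drops the typed statusResponse (and with it the context-length and memory fields) in favor of an untyped map, so the wire format reduces to two fields; note encoding/json emits map keys in alphabetical order:

{"progress":100,"status":0}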
@@ -157,7 +78,7 @@ func Execute(args []string) error {
     })

     mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
-        request := Request{Responses: make(chan CompletionResponse)}
+        request := Request{Responses: make(chan Response)}

         if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
             slog.Error("Failed to decode request", "error", err)
@@ -166,19 +87,16 @@ func Execute(args []string) error {
         }

         request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
-
-        sampling := resolveSamplingConfig(runner.Model, request)
+        if request.Options.MaxTokens < 1 {
+            request.Options.MaxTokens = 16 << 10
+        }

         request.Pipeline = runner.TextGenerationPipeline
         request.Sampler = sample.New(
-            sampling.temperature,
-            sampling.topP,
-            sampling.minP,
-            sampling.topK,
-            sampling.repeatLastN,
-            sampling.repeatPenalty,
-            sampling.presencePenalty,
-            sampling.frequencyPenalty,
+            request.Options.Temperature,
+            request.Options.TopP,
+            request.Options.MinP,
+            request.Options.TopK,
         )

         var cancel context.CancelFunc
@@ -1,172 +0,0 @@
-//go:build mlx
-
-package mlxrunner
-
-import (
-    "testing"
-
-    "github.com/ollama/ollama/api"
-    "github.com/ollama/ollama/x/mlxrunner/cache"
-    "github.com/ollama/ollama/x/mlxrunner/mlx"
-    "github.com/ollama/ollama/x/mlxrunner/model/base"
-    "github.com/ollama/ollama/x/models/qwen3_5"
-    "github.com/ollama/ollama/x/tokenizer"
-)
-
-type stubModel struct{}
-
-func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
-func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
-func (stubModel) NumLayers() int { return 0 }
-func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
-func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
-
-func TestResolveSamplingConfigDefaults(t *testing.T) {
-    trueValue := true
-    falseValue := false
-
-    tests := []struct {
-        name  string
-        model base.Model
-        req   Request
-        want  samplingConfig
-    }{
-        {
-            name:  "generic model uses api defaults",
-            model: stubModel{},
-            req:   Request{},
-            want: samplingConfig{
-                temperature:      0.8,
-                topP:             0.9,
-                minP:             0.0,
-                topK:             40,
-                repeatLastN:      64,
-                repeatPenalty:    1.1,
-                presencePenalty:  0.0,
-                frequencyPenalty: 0.0,
-            },
-        },
-        {
-            name:  "qwen3.5 defaults to thinking profile when think unset",
-            model: &qwen3_5.Model{},
-            req:   Request{},
-            want: samplingConfig{
-                temperature:      1.0,
-                topP:             0.95,
-                minP:             0.0,
-                topK:             20,
-                repeatLastN:      64,
-                repeatPenalty:    1.0,
-                presencePenalty:  1.5,
-                frequencyPenalty: 0.0,
-            },
-        },
-        {
-            name:  "qwen3.5 thinking disabled defaults",
-            model: &qwen3_5.Model{},
-            req:   Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
-            want: samplingConfig{
-                temperature:      0.7,
-                topP:             0.8,
-                minP:             0.0,
-                topK:             20,
-                repeatLastN:      64,
-                repeatPenalty:    1.0,
-                presencePenalty:  1.5,
-                frequencyPenalty: 0.0,
-            },
-        },
-        {
-            name:  "qwen3.5 thinking enabled defaults",
-            model: &qwen3_5.Model{},
-            req:   Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
-            want: samplingConfig{
-                temperature:      1.0,
-                topP:             0.95,
-                minP:             0.0,
-                topK:             20,
-                repeatLastN:      64,
-                repeatPenalty:    1.0,
-                presencePenalty:  1.5,
-                frequencyPenalty: 0.0,
-            },
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
-                t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
-            }
-        })
-    }
-}
-
-func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
-    trueValue := true
-    temperature := float32(0.4)
-    topP := float32(0.6)
-    minP := float32(0.05)
-    topK := 12
-    repeatLastN := 32
-    repeatPenalty := float32(1.1)
-    presencePenalty := float32(0.7)
-    frequencyPenalty := float32(0.2)
-
-    got := resolveSamplingConfig(stubModel{}, Request{
-        TextCompletionsRequest: TextCompletionsRequest{
-            Think: &trueValue,
-            Options: struct {
-                Temperature      *float32 `json:"temperature"`
-                TopP             *float32 `json:"top_p"`
-                MinP             *float32 `json:"min_p"`
-                TopK             *int     `json:"top_k"`
-                RepeatLastN      *int     `json:"repeat_last_n"`
-                RepeatPenalty    *float32 `json:"repeat_penalty"`
-                PresencePenalty  *float32 `json:"presence_penalty"`
-                FrequencyPenalty *float32 `json:"frequency_penalty"`
-                MaxTokens        int      `json:"max_tokens"`
-                NumPredict       int      `json:"num_predict"`
-            }{
-                Temperature:      &temperature,
-                TopP:             &topP,
-                MinP:             &minP,
-                TopK:             &topK,
-                RepeatLastN:      &repeatLastN,
-                RepeatPenalty:    &repeatPenalty,
-                PresencePenalty:  &presencePenalty,
-                FrequencyPenalty: &frequencyPenalty,
-            },
-        },
-    })
-
-    want := samplingConfig{
-        temperature:      temperature,
-        topP:             topP,
-        minP:             minP,
-        topK:             topK,
-        repeatLastN:      repeatLastN,
-        repeatPenalty:    repeatPenalty,
-        presencePenalty:  presencePenalty,
-        frequencyPenalty: frequencyPenalty,
-    }
-    if got != want {
-        t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
-    }
-}
-
-func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
-    want := api.DefaultOptions()
-    got := defaultSamplingConfig(stubModel{}, nil)
-
-    if got.temperature != want.Temperature ||
-        got.topP != want.TopP ||
-        got.minP != want.MinP ||
-        got.topK != want.TopK ||
-        got.repeatLastN != want.RepeatLastN ||
-        got.repeatPenalty != want.RepeatPenalty ||
-        got.presencePenalty != want.PresencePenalty ||
-        got.frequencyPenalty != want.FrequencyPenalty {
-        t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
-    }
-}
@@ -430,10 +430,6 @@ func (m *Model) NumLayers() int {
     return len(m.Layers)
 }

-func (m *Model) MaxContextLength() int {
-    return int(m.MaxPositionEmbeddings)
-}
-
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
     return m.tok
 }
@@ -733,7 +733,7 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
 func (m *Model) NumLayers() int { return len(m.Layers) }

 // MaxContextLength returns the maximum context length
-func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
+func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }

 // VocabSize returns the vocabulary size
 func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
@@ -262,10 +262,6 @@ func (m *Model) NumLayers() int {
     return len(m.Layers)
 }

-func (m *Model) MaxContextLength() int {
-    return int(m.MaxPositionEmbeddings)
-}
-
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
     return m.tok
 }
@@ -15,40 +15,6 @@ type LinearLayer interface {
     OutputDim() int32
 }

-// Conv1d applies 1D convolution over NLC input.
-type Conv1d struct {
-    Weight   *mlx.Array
-    Bias     *mlx.Array
-    Stride   int32
-    Padding  int32
-    Dilation int32
-    Groups   int32
-}
-
-func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
-    if stride <= 0 {
-        stride = 1
-    }
-    if dilation <= 0 {
-        dilation = 1
-    }
-    if groups <= 0 {
-        groups = 1
-    }
-    return &Conv1d{
-        Weight:   weight,
-        Bias:     bias,
-        Stride:   stride,
-        Padding:  padding,
-        Dilation: dilation,
-        Groups:   groups,
-    }
-}
-
-func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
-    return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
-}
-
 // Linear applies an affine transformation: y = x @ W.T + b
 type Linear struct {
     Weight *mlx.Array
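For reference, the removed Conv1d operated on NLC (batch, length, channels) input; the output sequence length follows the standard 1D-convolution relation, shown here as a small Go helper that is illustrative only, not part of the package:

// conv1dOutLen computes the sequence length a 1D convolution produces:
// L_out = (L + 2*padding - dilation*(kernel-1) - 1)/stride + 1.
// Example: conv1dOutLen(10, 4, 1, 0, 1) == 7.
func conv1dOutLen(l, kernel, stride, padding, dilation int) int {
    return (l+2*padding-dilation*(kernel-1)-1)/stride + 1
}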
@@ -279,10 +279,6 @@ func (m *Model) NumLayers() int {
     return len(m.Layers)
 }

-func (m *Model) MaxContextLength() int {
-    return int(m.MaxPositionEmbeddings)
-}
-
 func (m *Model) Tokenizer() *tokenizer.Tokenizer {
     return m.tok
 }
(File diff suppressed because it is too large.)
@@ -1,166 +0,0 @@
-//go:build mlx
-
-package qwen3_5
-
-import (
-    "testing"
-
-    "github.com/ollama/ollama/x/mlxrunner/cache"
-    "github.com/ollama/ollama/x/mlxrunner/mlx"
-)
-
-func TestParseConfigNestedDefaults(t *testing.T) {
-    data := []byte(`{
-        "model_type": "Qwen3_5MoeForConditionalGeneration",
-        "text_config": {
-            "hidden_size": 4096,
-            "intermediate_size": 14336,
-            "num_hidden_layers": 8,
-            "num_attention_heads": 32,
-            "num_key_value_heads": 8,
-            "head_dim": 128,
-            "linear_num_value_heads": 64,
-            "linear_num_key_heads": 16,
-            "linear_key_head_dim": 128,
-            "linear_value_head_dim": 128,
-            "linear_conv_kernel_dim": 4,
-            "num_experts": 16,
-            "num_experts_per_tok": 4,
-            "moe_intermediate_size": 2048,
-            "shared_expert_intermediate_size": 4096,
-            "rope_parameters": {
-                "rope_theta": 500000,
-                "partial_rotary_factor": 0.5
-            }
-        }
-    }`)
-
-    cfg, err := parseConfig(data)
-    if err != nil {
-        t.Fatalf("parseConfig failed: %v", err)
-    }
-
-    if cfg.RopeTheta != 500000 {
-        t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
-    }
-    if cfg.RopeDim != 64 {
-        t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
-    }
-    if cfg.FullAttentionInterval != 4 {
-        t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
-    }
-    if !cfg.NormTopKProb {
-        t.Fatalf("norm_topk_prob should default to true for MoE")
-    }
-}
-
-func TestLayerSelectionHelpers(t *testing.T) {
-    cfg := &Config{
-        NumHiddenLayers:       6,
-        FullAttentionInterval: 3,
-        NumExperts:            8,
-        DecoderSparseStep:     2,
-        MLPOnlyLayers:         []int32{1},
-    }
-
-    if !layerIsLinear(cfg, 0) {
-        t.Fatalf("layer 0 should be linear")
-    }
-    if layerIsLinear(cfg, 2) {
-        t.Fatalf("layer 2 should be full attention")
-    }
-
-    if layerUsesMoE(cfg, 1) {
-        t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
-    }
-    if !layerUsesMoE(cfg, 3) {
-        t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
-    }
-}
-
-func TestResolveTensorPathLayout(t *testing.T) {
-    dummy := mlx.New("dummy")
-
-    tests := []struct {
-        name          string
-        key           string
-        wantContainer string
-        wantModel     string
-    }{
-        {
-            name:          "standard",
-            key:           "model.embed_tokens.weight",
-            wantContainer: "",
-            wantModel:     "model.",
-        },
-        {
-            name:          "nested language model with inner model",
-            key:           "model.language_model.model.embed_tokens.weight",
-            wantContainer: "model.language_model.",
-            wantModel:     "model.",
-        },
-        {
-            name:          "nested language model without inner model",
-            key:           "model.language_model.embed_tokens.weight",
-            wantContainer: "model.language_model.",
-            wantModel:     "",
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            layout := resolveTensorPathLayout(map[string]*mlx.Array{
-                tt.key: dummy,
-            })
-
-            if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
-                t.Fatalf(
-                    "resolveTensorPathLayout() = {%q %q}, want {%q %q}",
-                    layout.containerPrefix,
-                    layout.modelPrefix,
-                    tt.wantContainer,
-                    tt.wantModel,
-                )
-            }
-        })
-    }
-}
-
-func TestModelRuntimeDefaults(t *testing.T) {
-    m := &Model{}
-    if m.DisablePromptCache() {
-        t.Fatal("DisablePromptCache() = true, want false")
-    }
-}
-
-func TestNewCachesLayout(t *testing.T) {
-    m := &Model{
-        Config: &Config{
-            LinearConvKernelDim: 4,
-            LinearNumKeyHeads:   2,
-            LinearKeyHeadDim:    8,
-            LinearNumValueHeads: 4,
-            LinearValueHeadDim:  16,
-        },
-        Layers: []*Layer{
-            {IsLinear: true},
-            {IsLinear: false},
-            {IsLinear: true},
-        },
-    }
-
-    caches := m.NewCaches()
-    if len(caches) != len(m.Layers) {
-        t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
-    }
-
-    if _, ok := caches[0].(*cache.RecurrentCache); !ok {
-        t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
-    }
-    if _, ok := caches[1].(*cache.KVCache); !ok {
-        t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
-    }
-    if _, ok := caches[2].(*cache.RecurrentCache); !ok {
-        t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
-    }
-}
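The RopeDim expectation of 64 in the removed TestParseConfigNestedDefaults follows from the config values it feeds in: the rotary width is head_dim scaled by partial_rotary_factor, i.e. 128 * 0.5 = 64. As a one-line illustration (the helper is hypothetical; only the field names come from the test):

// ropeDim derives the rotary embedding width: headDim=128, factor=0.5 -> 64.
func ropeDim(headDim int32, partialRotaryFactor float64) int32 {
    return int32(float64(headDim) * partialRotaryFactor)
}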
@@ -1,16 +0,0 @@
-//go:build mlx
-
-// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
-package qwen3_5_moe
-
-import (
-    "github.com/ollama/ollama/x/mlxrunner/model/base"
-    "github.com/ollama/ollama/x/models/qwen3_5"
-)
-
-func init() {
-    base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
-    base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
-    base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
-    base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
-}