Mirror of https://github.com/ollama/ollama.git (synced 2026-02-10 07:33:27 -05:00)

Compare commits: main...pdevine/me

3 commits:
- 4e9e49523c
- 7b55952e4d
- 75c452f3a1
go.mod (5 changes)
@@ -13,7 +13,7 @@ require (
 	github.com/mattn/go-sqlite3 v1.14.24
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
-	github.com/stretchr/testify v1.9.0
+	github.com/stretchr/testify v1.10.0
 	github.com/x448/float16 v0.8.4
 	golang.org/x/sync v0.17.0
 	golang.org/x/sys v0.37.0
@@ -31,6 +31,8 @@ require (
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
+	github.com/tkrajina/typescriptify-golang-structs v0.2.0
+	github.com/tree-sitter/go-tree-sitter v0.25.0
 	github.com/tree-sitter/tree-sitter-cpp v0.23.4
 	github.com/wk8/go-ordered-map/v2 v2.1.8
 	golang.org/x/image v0.22.0
 	golang.org/x/mod v0.30.0
@@ -60,6 +62,7 @@ require (
 	github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
 	github.com/mailru/easyjson v0.7.7 // indirect
 	github.com/mattn/go-localereader v0.0.1 // indirect
+	github.com/mattn/go-pointer v0.0.1 // indirect
 	github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
 	github.com/muesli/cancelreader v0.2.2 // indirect
 	github.com/muesli/termenv v0.16.0 // indirect
go.sum (31 changes)
@@ -172,6 +172,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
 github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
+github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
+github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
 github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
 github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
 github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -233,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
+github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
+github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
+github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
+github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
+github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
 github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
 github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
 github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
 github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
+github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
+github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
+github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
+github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
+github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
+github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
+github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
+github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
+github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
+github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
+github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
+github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
+github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
+github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
+github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
+github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
+github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
+github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
+github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
+github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -4,6 +4,7 @@ import (
 	"github.com/ollama/ollama/runner/llamarunner"
 	"github.com/ollama/ollama/runner/ollamarunner"
 	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/mlxrunner"
 )

 func Execute(args []string) error {
@@ -17,6 +18,8 @@ func Execute(args []string) error {
 			return ollamarunner.Execute(args[1:])
 		case "--imagegen-engine":
 			return imagegen.Execute(args[1:])
+		case "--mlx-engine":
+			return mlxrunner.Execute(args[1:])
 		}
 	}
 	return llamarunner.Execute(args)
@@ -19,6 +19,7 @@ import (
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/x/create"
+	"github.com/ollama/ollama/x/imagegen/safetensors"
 )

 // MinOllamaVersion is the minimum Ollama version required for safetensors models.
@@ -35,7 +36,7 @@ type ModelfileConfig struct {
 type CreateOptions struct {
 	ModelName string
 	ModelDir  string
-	Quantize  string // "q4", "q8", "nvfp4", or "mxfp8" for quantization
+	Quantize  string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
 	Modelfile *ModelfileConfig // template/system/license from Modelfile
 }

@@ -94,6 +95,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
 			newLayerCreator(), newTensorLayerCreator(),
 			newManifestWriter(opts, capabilities, parserName, rendererName),
 			progressFn,
+			newPackedTensorLayerCreator(),
 		)
 	} else {
 		err = create.CreateImageGenModel(
@@ -141,60 +143,33 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
 	}
 }

-// createQuantizedLayers quantizes a tensor and returns the resulting layers.
+// createQuantizedLayers quantizes a tensor and returns a single combined layer.
+// The combined blob contains data, scale, and optional bias tensors with metadata.
 func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
 	if !QuantizeSupported() {
 		return nil, fmt.Errorf("quantization requires MLX support")
 	}

-	// Quantize the tensor
-	qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
+	// Quantize the tensor into a single combined blob
+	blobData, err := quantizeTensor(r, name, dtype, shape, quantize)
 	if err != nil {
 		return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
 	}

-	// Create layer for quantized weight
-	weightLayer, err := manifest.NewLayer(bytes.NewReader(qweightData), manifest.MediaTypeImageTensor)
+	// Create single layer for the combined blob
+	layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor)
 	if err != nil {
 		return nil, err
 	}

-	// Create layer for scales
-	scalesLayer, err := manifest.NewLayer(bytes.NewReader(scalesData), manifest.MediaTypeImageTensor)
-	if err != nil {
-		return nil, err
-	}
-
-	layers := []create.LayerInfo{
+	return []create.LayerInfo{
 		{
-			Digest:    weightLayer.Digest,
-			Size:      weightLayer.Size,
-			MediaType: weightLayer.MediaType,
+			Digest:    layer.Digest,
+			Size:      layer.Size,
+			MediaType: layer.MediaType,
 			Name:      name,
 		},
-		{
-			Digest:    scalesLayer.Digest,
-			Size:      scalesLayer.Size,
-			MediaType: scalesLayer.MediaType,
-			Name:      name + "_scale",
-		},
-	}
-
-	// Add qbiases layer if present (affine mode)
-	if qbiasData != nil {
-		qbiasLayer, err := manifest.NewLayer(bytes.NewReader(qbiasData), manifest.MediaTypeImageTensor)
-		if err != nil {
-			return nil, err
-		}
-		layers = append(layers, create.LayerInfo{
-			Digest:    qbiasLayer.Digest,
-			Size:      qbiasLayer.Size,
-			MediaType: qbiasLayer.MediaType,
-			Name:      name + "_qbias",
-		})
-	}
-
-	return layers, nil
+	}, nil
 }

 // createUnquantizedLayer creates a single tensor layer without quantization.
@@ -214,6 +189,58 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
 	}, nil
 }

+// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for
+// creating packed multi-tensor blob layers (used for expert groups).
+func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
+	return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) {
+		// Check if any tensor in the group needs quantization
+		hasQuantize := false
+		for _, t := range tensors {
+			if t.Quantize != "" {
+				hasQuantize = true
+				break
+			}
+		}
+
+		var blobReader io.Reader
+		if hasQuantize {
+			if !QuantizeSupported() {
+				return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
+			}
+			blobData, err := quantizePackedGroup(tensors)
+			if err != nil {
+				return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
+			}
+			blobReader = bytes.NewReader(blobData)
+		} else {
+			// Build unquantized packed blob using streaming reader
+			// Extract raw tensor data from safetensors-wrapped readers
+			var tds []*safetensors.TensorData
+			for _, t := range tensors {
+				rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader)
+				if err != nil {
+					return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err)
+				}
+				td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData)
+				tds = append(tds, td)
+			}
+			blobReader = safetensors.BuildPackedSafetensorsReader(tds)
+		}
+
+		layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor)
+		if err != nil {
+			return create.LayerInfo{}, err
+		}
+
+		return create.LayerInfo{
+			Digest:    layer.Digest,
+			Size:      layer.Size,
+			MediaType: layer.MediaType,
+			Name:      groupName,
+		}, nil
+	}
+}
+
 // newManifestWriter returns a ManifestWriter callback for writing the model manifest.
 func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
 	return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
@@ -3,128 +3,195 @@
 package client

 import (
+	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"io"
 	"os"
 	"path/filepath"
+	"strconv"

+	"github.com/ollama/ollama/x/create"
 	"github.com/ollama/ollama/x/imagegen/mlx"
 )

-// quantizeTensor loads a tensor from safetensors format, quantizes it,
-// and returns safetensors data for the quantized weights, scales, and biases.
-// Supported quantization types:
-//   - "q4": affine 4-bit, group_size=32 (with qbiases)
-//   - "nvfp4": NVIDIA FP4, group_size=16 (no qbiases, E4M3 scales)
-//   - "q8": affine 8-bit, group_size=64 (with qbiases)
-//   - "mxfp8": Microsoft MX FP8, group_size=32 (no qbiases, E4M3 scales)
-// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
-func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
+// quantizeParams maps quantization type names to MLX quantize parameters.
+var quantizeParams = map[string]struct {
+	groupSize int
+	bits      int
+	mode      string
+}{
+	"int4":  {32, 4, "affine"},
+	"nvfp4": {16, 4, "nvfp4"},
+	"int8":  {64, 8, "affine"},
+	"mxfp8": {32, 8, "mxfp8"},
+}
+
+// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
+// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
+// to the provided maps. If quantize is empty, the tensor is kept as-is.
+// Returns any temp file paths created (caller must clean up) and arrays needing eval.
+func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
 	tmpDir := ensureTempDir()

 	// Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
-	tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
+	tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
 	if err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
+		return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
 	}
-	tmpPath := tmpFile.Name()
-	defer os.Remove(tmpPath)
+	tmpPath = tmpFile.Name()

 	if _, err := io.Copy(tmpFile, r); err != nil {
 		tmpFile.Close()
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
+		return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
 	}
 	tmpFile.Close()

 	// Load the tensor using MLX's native loader
 	st, err := mlx.LoadSafetensorsNative(tmpPath)
 	if err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
+		return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
 	}
-	defer st.Free()

-	// Get the tensor (it's stored as "data" in our minimal safetensors format)
-	arr := st.Get("data")
+	// Find the tensor key (may differ from name for single-tensor blobs)
+	inputKey, err := findSafetensorsKey(tmpPath)
+	if err != nil {
+		st.Free()
+		return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
+	}
+
+	arr := st.Get(inputKey)
 	if arr == nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
+		st.Free()
+		return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
 	}

-	// Convert to BFloat16 if needed (quantize expects float type)
+	if quantize == "" {
+		arr = mlx.Contiguous(arr)
+		arrays[name] = arr
+		return tmpPath, []*mlx.Array{arr}, st, nil
+	}
+
+	// Convert to float type if needed (quantize expects float)
 	if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
 		arr = mlx.AsType(arr, mlx.DtypeBFloat16)
 		mlx.Eval(arr)
 	}

-	// Quantize based on quantization type
-	var qweight, scales, qbiases *mlx.Array
-	switch quantize {
-	case "q4":
-		// affine mode: group_size=32, bits=4 (with qbiases for zero-point offset)
-		qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
-	case "nvfp4":
-		// NVIDIA FP4: group_size=16, bits=4 (no qbiases, E4M3 scales)
-		qweight, scales, qbiases = mlx.Quantize(arr, 16, 4, "nvfp4")
-	case "q8":
-		// affine mode: group_size=64, bits=8 (with qbiases for zero-point offset)
-		qweight, scales, qbiases = mlx.Quantize(arr, 64, 8, "affine")
-	case "mxfp8":
-		// Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbiases)
-		qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "mxfp8")
-	default:
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
+	params, ok := quantizeParams[quantize]
+	if !ok {
+		st.Free()
+		return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
 	}

-	// Eval and make contiguous for data access
+	qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
+
 	qweight = mlx.Contiguous(qweight)
 	scales = mlx.Contiguous(scales)
+	arrays[name] = qweight
+	arrays[name+".scale"] = scales
+	toEval = append(toEval, qweight, scales)
+
 	if qbiases != nil {
 		qbiases = mlx.Contiguous(qbiases)
-		mlx.Eval(qweight, scales, qbiases)
-	} else {
-		mlx.Eval(qweight, scales)
+		arrays[name+".bias"] = qbiases
+		toEval = append(toEval, qbiases)
 	}

-	// Get shapes
-	qweightShape = qweight.Shape()
-	scalesShape = scales.Shape()
+	return tmpPath, toEval, st, nil
+}

-	// Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
-	qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
-	defer os.Remove(qweightPath)
-	if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
+// quantizeTensor loads a tensor from safetensors format, quantizes it,
+// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
+// Tensor keys use the original tensor name: name, name.scale, name.bias.
+// The blob includes __metadata__ with quant_type and group_size.
+// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
+func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
+	arrays := make(map[string]*mlx.Array)
+	tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
+	if tmpPath != "" {
+		defer os.Remove(tmpPath)
+	}
+	if st != nil {
+		defer st.Free()
 	}
-	qweightData, err = os.ReadFile(qweightPath)
 	if err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
+		return nil, err
 	}

-	// Save scales using MLX's native safetensors
-	scalesPath := filepath.Join(tmpDir, "scales.safetensors")
-	defer os.Remove(scalesPath)
-	if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
-	}
-	scalesData, err = os.ReadFile(scalesPath)
-	if err != nil {
-		return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
+	mlx.Eval(toEval...)
+
+	// Build metadata for single-tensor blobs
+	params := quantizeParams[quantize]
+	metadata := map[string]string{
+		"quant_type": quantize,
+		"group_size": strconv.Itoa(params.groupSize),
 	}

-	// Affine mode returns qbiases for zero-point offset
-	if qbiases != nil {
-		qbiasShape = qbiases.Shape()
-		qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
-		defer os.Remove(qbiasPath)
-		if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
-			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
+	tmpDir := ensureTempDir()
+	outPath := filepath.Join(tmpDir, "combined.safetensors")
+	defer os.Remove(outPath)
+	if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
+		return nil, fmt.Errorf("failed to save combined blob: %w", err)
+	}
+	return os.ReadFile(outPath)
+}
+
+// quantizePackedGroup quantizes multiple tensors and saves them all into a single
+// combined safetensors blob. Used for packing expert groups.
+// Each tensor may have a different quantization type (mixed-precision).
+// Returns the blob bytes. No __metadata__ is added because different tensors
+// may use different quantization types.
+func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
+	allArrays := make(map[string]*mlx.Array)
+	var allToEval []*mlx.Array
+	var tmpPaths []string
+	var handles []*mlx.SafetensorsFile
+
+	for _, input := range inputs {
+		tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
+		if tmpPath != "" {
+			tmpPaths = append(tmpPaths, tmpPath)
+		}
+		if st != nil {
+			handles = append(handles, st)
 		}
-		qbiasData, err = os.ReadFile(qbiasPath)
 		if err != nil {
-			return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
+			// Cleanup on error
+			for _, h := range handles {
+				h.Free()
+			}
+			for _, p := range tmpPaths {
+				os.Remove(p)
+			}
+			return nil, err
 		}
+		allToEval = append(allToEval, toEval...)
 	}

-	return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
+	mlx.Eval(allToEval...)
+
+	// Free native handles after eval
+	for _, h := range handles {
+		h.Free()
+	}
+
+	// Save combined blob (no global metadata for mixed-precision packed blobs)
+	tmpDir := ensureTempDir()
+	outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
+	defer os.Remove(outPath)
+	if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
+		return nil, fmt.Errorf("failed to save packed blob: %w", err)
+	}
+
+	blobData, err := os.ReadFile(outPath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read packed blob: %w", err)
+	}
+
+	for _, p := range tmpPaths {
+		os.Remove(p)
+	}
+
+	return blobData, nil
 }

 // QuantizeSupported returns true if quantization is supported (MLX build)
@@ -138,3 +205,33 @@ func ensureTempDir() string {
 	os.MkdirAll(tmpDir, 0755)
 	return tmpDir
 }
+
+// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
+func findSafetensorsKey(path string) (string, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	var headerSize uint64
+	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
+		return "", err
+	}
+	headerBytes := make([]byte, headerSize)
+	if _, err := io.ReadFull(f, headerBytes); err != nil {
+		return "", err
+	}
+
+	var header map[string]json.RawMessage
+	if err := json.Unmarshal(headerBytes, &header); err != nil {
+		return "", err
+	}
+
+	for k := range header {
+		if k != "__metadata__" {
+			return k, nil
+		}
+	}
+	return "", fmt.Errorf("no tensor found in safetensors header")
+}
@@ -5,11 +5,18 @@ package client
 import (
 	"fmt"
 	"io"

+	"github.com/ollama/ollama/x/create"
 )

 // quantizeTensor is not available without MLX
-func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
-	return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
+func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
+	return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
 }

+// quantizePackedGroup is not available without MLX
+func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
+	return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
+}
+
 // QuantizeSupported returns false when MLX is not available
@@ -6,7 +6,9 @@ import (
 	"io"
 	"os"
 	"path/filepath"
+	"regexp"
 	"slices"
+	"sort"
 	"strings"

 	"github.com/ollama/ollama/envconfig"
@@ -228,7 +230,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
 type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)

 // QuantizingTensorLayerCreator creates tensor layers with optional quantization.
-// When quantize is non-empty (e.g., "q8"), returns multiple layers (weight + scales + biases).
+// When quantize is non-empty (e.g., "int8"), returns multiple layers (weight + scales + biases).
 type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)

 // ManifestWriter writes the manifest file.
@@ -264,19 +266,19 @@ func ShouldQuantize(name, component string) bool {

 // ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
 // This is a more detailed check that also considers tensor dimensions.
-// The quantize parameter specifies the quantization type (e.g., "q4", "nvfp4", "q8", "mxfp8").
+// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
 func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
 	return GetTensorQuantization(name, shape, quantize) != ""
 }

 // normalizeQuantType converts various quantization type aliases to canonical forms.
-// Supports: q4/Q4/int4/INT4/fp4/FP4 -> q4, q8/Q8/int8/INT8/fp8/FP8 -> q8, nvfp4/NVFP4, mxfp8/MXFP8
+// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
 func normalizeQuantType(quantize string) string {
 	switch strings.ToUpper(quantize) {
 	case "Q4", "INT4", "FP4":
-		return "q4"
+		return "int4"
 	case "Q8", "INT8", "FP8":
-		return "q8"
+		return "int8"
 	case "NVFP4":
 		return "nvfp4"
 	case "MXFP8":
@@ -286,29 +288,12 @@ func normalizeQuantType(quantize string) string {
 	}
 }

-// getQuantGroupSize returns the group size for a given quantization type.
-// These must match the values used in quantize.go when creating quantized models.
-func getQuantGroupSize(quantize string) int {
-	switch normalizeQuantType(quantize) {
-	case "nvfp4":
-		return 16
-	case "q4":
-		return 32
-	case "mxfp8":
-		return 32
-	case "q8":
-		return 64
-	default:
-		return 32
-	}
-}
-
 // GetTensorQuantization returns the appropriate quantization type for a tensor.
 // Returns "" if the tensor should not be quantized.
 // This implements mixed-precision quantization:
 //   - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
-//   - Output projection, gate/up weights: q4 (less sensitive)
-//   - Down projection weights: q8 (more sensitive, would be Q6 in GGML but no MLX kernel)
+//   - Output projection, gate/up weights: int4 (less sensitive)
+//   - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
 //   - Norms, embeddings, biases, routing gates: no quantization
 func GetTensorQuantization(name string, shape []int32, quantize string) string {
 	// Use basic name-based check first
@@ -330,12 +315,12 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
 	quantNorm := normalizeQuantType(quantize)

 	// MLX quantization requires last dimension to be divisible by group size
-	// nvfp4: 16, q4/mxfp8: 32, q8: 64
+	// nvfp4: 16, int4/mxfp8: 32, int8: 64
 	groupSize := int32(32)
 	switch quantNorm {
 	case "nvfp4":
 		groupSize = 16
-	case "q8":
+	case "int8":
 		groupSize = 64
 	}
 	if shape[len(shape)-1]%groupSize != 0 {
@@ -363,13 +348,13 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
 		return "" // No quantization - keep bf16
 	}

-	// Down projection weights - use Q8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
+	// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
 	// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
 	if strings.Contains(name, "down_proj") {
-		return "q8"
+		return "int8"
 	}

-	// Output projection, gate/up weights - use requested quantization (Q4)
+	// Output projection, gate/up weights - use requested quantization (INT4)
 	// o_proj, gate_proj, up_proj
 	if strings.Contains(name, "o_proj") ||
 		strings.Contains(name, "gate_proj") ||
@@ -386,14 +371,69 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
 	return quantNorm
 }

+// expertGroupRegexp matches expert tensor names and captures the group prefix.
+// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
+// Captures: model.layers.{L}.mlp.experts
+var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
+
+// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
+// For example:
+//   - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
+//   - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
+//   - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
+//   - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
+func ExpertGroupPrefix(tensorName string) string {
+	m := expertGroupRegexp.FindStringSubmatch(tensorName)
+	if m == nil {
+		return ""
+	}
+	return m[1]
+}
+
+// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
+type PackedTensorInput struct {
+	Name     string
+	Dtype    string
+	Shape    []int32
+	Quantize string    // per-tensor quantization type (may differ within group)
+	Reader   io.Reader // safetensors-wrapped tensor data
+}
+
+// PackedTensorLayerCreator creates a single blob layer containing multiple packed tensors.
+// groupName is the group prefix (e.g., "model.layers.1.mlp.experts").
+type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
+
 // CreateSafetensorsModel imports a standard safetensors model from a directory.
 // This handles Hugging Face style models with config.json and *.safetensors files.
 // Stores each tensor as a separate blob for fine-grained deduplication.
-// If quantize is non-empty (e.g., "q8"), eligible tensors will be quantized.
-func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
+// Expert tensors are packed into per-layer blobs when createPackedLayer is non-nil.
+// If quantize is non-empty (e.g., "int8"), eligible tensors will be quantized.
+func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error {
 	var layers []LayerInfo
 	var configLayer LayerInfo

+	// Resolve the optional packed layer creator
+	var packedCreator PackedTensorLayerCreator
+	if len(createPackedLayer) > 0 {
+		packedCreator = createPackedLayer[0]
+	}
+
+	// Accumulate expert tensors by group prefix for packing.
+	// Readers reference file-backed SectionReaders, so we keep extractors
+	// open until each group is flushed to avoid buffering tensor data in memory.
+	expertGroups := make(map[string][]PackedTensorInput)
+	var expertGroupOrder []string
+
+	// Track open extractors so we can close them after flushing groups
+	var openExtractors []*safetensors.TensorExtractor
+
+	closeExtractors := func() {
+		for _, ext := range openExtractors {
+			ext.Close()
+		}
+		openExtractors = nil
+	}
+
 	entries, err := os.ReadDir(modelDir)
 	if err != nil {
 		return fmt.Errorf("failed to read directory: %w", err)
@@ -410,6 +450,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 		// Extract individual tensors from safetensors file
 		extractor, err := safetensors.OpenForExtraction(stPath)
 		if err != nil {
+			closeExtractors()
 			return fmt.Errorf("failed to open %s: %w", stPath, err)
 		}

@@ -420,10 +461,14 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 		}
 		fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))

+		// Track whether this extractor has expert tensors that need to stay open
+		hasExpertTensors := false
+
 		for _, tensorName := range tensorNames {
 			td, err := extractor.GetTensor(tensorName)
 			if err != nil {
 				extractor.Close()
+				closeExtractors()
 				return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
 			}

@@ -434,20 +479,65 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 				quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
 			}

-			// Store as minimal safetensors format (88 bytes header overhead)
-			// This enables native mmap loading via mlx_load_safetensors
-			// createTensorLayer returns multiple layers if quantizing (weight + scales)
-			newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
-			if err != nil {
-				extractor.Close()
-				return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
+			// Check if this tensor belongs to an expert group for packing
+			groupPrefix := ""
+			if packedCreator != nil {
+				groupPrefix = ExpertGroupPrefix(tensorName)
 			}

+			if groupPrefix != "" {
+				// Accumulate expert tensor for packed blob.
+				// The Reader uses a file-backed SectionReader, so we must
+				// keep the extractor open until this group is flushed.
+				hasExpertTensors = true
+				if _, exists := expertGroups[groupPrefix]; !exists {
+					expertGroupOrder = append(expertGroupOrder, groupPrefix)
+				}
+				expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
+					Name:     tensorName,
+					Dtype:    td.Dtype,
+					Shape:    td.Shape,
+					Quantize: quantizeType,
+					Reader:   td.SafetensorsReader(),
+				})
+			} else {
+				// Store as minimal safetensors format (88 bytes header overhead)
+				// This enables native mmap loading via mlx_load_safetensors
+				// createTensorLayer returns multiple layers if quantizing (weight + scales)
+				newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
+				if err != nil {
+					extractor.Close()
+					closeExtractors()
+					return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
+				}
+				layers = append(layers, newLayers...)
+			}
-			layers = append(layers, newLayers...)
 		}

-		extractor.Close()
+		if hasExpertTensors {
+			// Keep extractor open - readers still reference its file handle
+			openExtractors = append(openExtractors, extractor)
+		} else {
+			extractor.Close()
+		}
 	}

+	// Process accumulated expert groups into packed blobs, then close extractors
+	if packedCreator != nil {
+		sort.Strings(expertGroupOrder)
+		for _, groupName := range expertGroupOrder {
+			tensors := expertGroups[groupName]
+			fn(fmt.Sprintf("packing %s (%d tensors)", groupName, len(tensors)))
+			layer, err := packedCreator(groupName, tensors)
+			if err != nil {
+				closeExtractors()
+				return fmt.Errorf("failed to create packed layer for %s: %w", groupName, err)
+			}
+			layers = append(layers, layer)
+		}
+	}
+	closeExtractors()
+
 	// Process all JSON config files
 	for _, entry := range entries {
 		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
@@ -487,23 +577,6 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
 		return fmt.Errorf("config.json not found in %s", modelDir)
 	}

-	// Create model_index.json with quantization info if quantizing
-	if quantize != "" {
-		modelIndex := map[string]any{
-			"quantization": strings.ToUpper(quantize),
-			"group_size":   getQuantGroupSize(quantize),
-		}
-		indexData, err := json.MarshalIndent(modelIndex, "", " ")
-		if err != nil {
-			return fmt.Errorf("failed to marshal model_index.json: %w", err)
-		}
-		indexLayer, err := createLayer(strings.NewReader(string(indexData)), "application/vnd.ollama.image.json", "model_index.json")
-		if err != nil {
-			return fmt.Errorf("failed to create model_index.json layer: %w", err)
-		}
-		layers = append(layers, indexLayer)
-	}
-
 	fn(fmt.Sprintf("writing manifest for %s", modelName))

 	if err := writeManifest(modelName, configLayer, layers); err != nil {
@@ -586,6 +586,39 @@ func TestShouldQuantizeTensor(t *testing.T) {
 	}
 }

+func TestExpertGroupPrefix(t *testing.T) {
+	tests := []struct {
+		name string
+		want string
+	}{
+		// Expert tensors should return the group prefix
+		{"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
+		{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
+		{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
+
+		// Shared expert tensors should return their own group prefix
+		{"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
+		{"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
+
+		// Non-expert tensors should return empty string
+		{"model.layers.0.mlp.down_proj.weight", ""},    // dense layer, no experts
+		{"model.layers.1.mlp.gate.weight", ""},         // routing gate, not an expert
+		{"model.embed_tokens.weight", ""},              // embedding
+		{"model.layers.0.self_attn.q_proj.weight", ""}, // attention
+		{"model.norm.weight", ""},                      // norm
+		{"lm_head.weight", ""},                         // output head
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := ExpertGroupPrefix(tt.name)
+			if got != tt.want {
+				t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want)
+			}
+		})
+	}
+}
+
 func TestCreateSafetensorsModel_WithQuantize(t *testing.T) {
 	dir := t.TempDir()

@@ -751,7 +784,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {

 	progressFn := func(status string) {}

-	err := CreateImageGenModel("test-imagegen", dir, "q8", createLayer, createTensorLayer, writeManifest, progressFn)
+	err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
 	if err != nil {
 		t.Fatalf("CreateImageGenModel failed: %v", err)
 	}
@@ -15,15 +15,15 @@ import (
 // CreateImageGenModel imports an image generation model from a directory.
 // Stores each tensor as a separate blob for fine-grained deduplication.
 // If quantize is specified, linear weights in transformer/text_encoder are quantized.
-// Supported quantization types: q4, q8, nvfp4, mxfp8 (or empty for no quantization).
+// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
 // Layer creation and manifest writing are done via callbacks to avoid import cycles.
 func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
 	// Validate quantization type
 	switch quantize {
-	case "", "q4", "q8", "nvfp4", "mxfp8":
+	case "", "int4", "int8", "nvfp4", "mxfp8":
 		// valid
 	default:
-		return fmt.Errorf("unsupported quantization type %q: supported types are q4, q8, nvfp4, mxfp8", quantize)
+		return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
 	}

 	var layers []LayerInfo
@@ -214,7 +214,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer

 // canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
 // MLX requires the last dimension to be divisible by the group size.
-// nvfp4: 16, q4/mxfp8: 32, q8: 64
+// nvfp4: 16, int4/mxfp8: 32, int8: 64
 func canQuantizeShape(shape []int32, quantize string) bool {
 	if len(shape) < 2 {
 		return false
@@ -223,7 +223,7 @@ func canQuantizeShape(shape []int32, quantize string) bool {
 	switch strings.ToUpper(quantize) {
 	case "NVFP4":
 		groupSize = 16
-	case "Q8":
+	case "INT8":
 		groupSize = 64
 	}
 	return shape[len(shape)-1]%groupSize == 0
x/imagegen/docs/blob-format.md (new file, 158 lines)
@@ -0,0 +1,158 @@

# Tensor Blob Format

Ollama stores model tensors as individual blobs in the safetensors format. Each blob contains either a single logical tensor (or a combined quantized tensor with its scale/bias components) or a group of logical tensors (e.g. the shared experts for a given layer, along with the scale/bias components for each tensor).

## Safetensors File Format

Every blob follows the [safetensors](https://github.com/huggingface/safetensors) layout:

```
[8 bytes: header_size (uint64 LE)] [header_size bytes: JSON header] [tensor data region]
```

The JSON header maps tensor names to their dtype, shape, and byte offsets within the data region. A special `__metadata__` key holds string-to-string metadata.
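As a sketch, the header can be read in a few lines of Go (this mirrors the `readBlobHeader` and `findSafetensorsKey` helpers added elsewhere in this change; the snippet itself is illustrative, not repository code):

```go
// Illustrative: read the JSON header of a safetensors blob.
package main

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"os"
)

func readHeader(path string) (map[string]json.RawMessage, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	// First 8 bytes: little-endian uint64 header size.
	var headerSize uint64
	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
		return nil, err
	}
	// Next headerSize bytes: JSON mapping tensor names to dtype/shape/offsets,
	// plus the optional "__metadata__" entry.
	buf := make([]byte, headerSize)
	if _, err := io.ReadFull(f, buf); err != nil {
		return nil, err
	}
	var header map[string]json.RawMessage
	err = json.Unmarshal(buf, &header)
	return header, err
}

func main() {
	header, err := readHeader(os.Args[1])
	if err != nil {
		panic(err)
	}
	for key := range header {
		fmt.Println(key) // tensor keys, plus "__metadata__" when present
	}
}
```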
## Unquantized Blobs

An unquantized blob stores a single tensor keyed by its name:

```json
{
  "model.layers.0.self_attn.q_proj.weight": {
    "dtype": "BF16",
    "shape": [2560, 2560],
    "data_offsets": [0, 13107200]
  }
}
```

The tensor key is the full tensor name. Dtype is typically `BF16` or `F32`.
## Quantized Blobs (Combined Format)

A quantized blob stores the packed weight, scaling factors, and optional zero-point biases in a single file. Tensor keys use the tensor name, with `.scale` and `.bias` suffixes for the auxiliary tensors:

```json
{
  "__metadata__": {
    "quant_type": "int4",
    "group_size": "32"
  },
  "model.layers.0.mlp.up_proj.weight": {
    "dtype": "U32",
    "shape": [2560, 320],
    "data_offsets": [0, 3276800]
  },
  "model.layers.0.mlp.up_proj.weight.scale": {
    "dtype": "BF16",
    "shape": [2560, 80],
    "data_offsets": [3276800, 3686400]
  },
  "model.layers.0.mlp.up_proj.weight.bias": {
    "dtype": "BF16",
    "shape": [2560, 80],
    "data_offsets": [3686400, 4096000]
  }
}
```

### Metadata Fields

| Field | Description |
|---|---|
| `quant_type` | Quantization type: `int4`, `int8`, `nvfp4`, or `mxfp8` |
| `group_size` | Number of elements per quantization group (e.g., `32`, `64`) |

### Tensor Keys

| Key | Description |
|---|---|
| `{name}` | Packed quantized weights (dtype `U32`) |
| `{name}.scale` | Per-group scaling factors |
| `{name}.bias` | Per-group zero-point offsets (affine modes only) |
## Quantization Types

| Type | Bits | Group Size | Mode | Has Bias |
|---|---|---|---|---|
| `int4` | 4 | 32 | affine | yes |
| `int8` | 8 | 64 | affine | yes |
| `nvfp4` | 4 | 16 | nvfp4 | no |
| `mxfp8` | 8 | 32 | mxfp8 | no |

**Affine modes** (`int4`, `int8`) use `scale + bias` for dequantization. The bias tensor provides the zero-point offset.

**Non-affine modes** (`nvfp4`, `mxfp8`) use only `scale` with specialized E4M3 scale formats.
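For intuition, here is a minimal sketch of affine dequantization for a run of int4 values (the nibble order within each uint32 is an assumption here; the authoritative packing is defined by MLX's quantize kernels):

```go
// Illustrative sketch, not repository code.
// Affine dequantization: w ≈ scale*q + bias, with q an unsigned 4-bit value.
// Assumes least-significant-nibble-first packing within each uint32.
func dequantizeInt4(packed []uint32, scale, bias float32) []float32 {
	out := make([]float32, 0, len(packed)*8)
	for _, word := range packed {
		for i := 0; i < 8; i++ { // 8 4-bit values per uint32
			q := float32((word >> (4 * i)) & 0xF)
			out = append(out, scale*q+bias)
		}
	}
	return out
}
```

In practice `scale` and `bias` vary per group of `group_size` consecutive elements, so a full implementation indexes into the `.scale` and `.bias` tensors rather than using single values.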
### Packed Weight Shape

Quantized weights are packed into `uint32` values:

- **4-bit** (int4, nvfp4): 8 values per uint32, so `packed_cols = original_cols / 8`
- **8-bit** (int8, mxfp8): 4 values per uint32, so `packed_cols = original_cols / 4`

Scale shape: `[rows, original_cols / group_size]`
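A quick sanity check of this arithmetic against the int4 example blob above (source shape `[2560, 2560]`, group size 32):

```go
// Illustrative sketch: packed weight and scale shapes for a quantized tensor.
// For the int4 example: 2560 cols / 8 values-per-uint32 = 320 packed cols,
// and 2560 / 32 group size = 80 scale cols, matching the header shown earlier.
func packedShapes(rows, cols, bits, groupSize int32) (weight, scale [2]int32) {
	valuesPerWord := 32 / bits // 8 for 4-bit, 4 for 8-bit
	weight = [2]int32{rows, cols / valuesPerWord}
	scale = [2]int32{rows, cols / groupSize}
	return weight, scale
}
```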
## Manifest References

Blobs are referenced from the model manifest as layers:

```json
{
  "mediaType": "application/vnd.ollama.image.tensor",
  "digest": "sha256:abc123...",
  "size": 4096150,
  "name": "model.layers.0.mlp.up_proj.weight"
}
```

Each tensor (quantized or not) is one layer in the manifest. The layer name matches the tensor key in the blob header.
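A layer entry of this shape can be decoded with a small struct (field names here are inferred from the JSON keys above; the repository's own manifest types may differ):

```go
// Illustrative sketch, not the repository's ManifestLayer definition.
type TensorLayer struct {
	MediaType string `json:"mediaType"` // e.g. "application/vnd.ollama.image.tensor"
	Digest    string `json:"digest"`    // sha256 digest of the blob
	Size      int64  `json:"size"`      // blob size in bytes
	Name      string `json:"name"`      // tensor key in the blob header
}
```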
## Packed Blobs (Expert Groups)

For MoE (Mixture of Experts) models, expert tensors from the same layer are packed into a single blob to reduce blob count and improve loading efficiency. A packed blob is a standard safetensors file containing multiple tensor entries:

```json
{
  "model.layers.1.mlp.experts.0.down_proj.weight": {
    "dtype": "U32",
    "shape": [2560, 640],
    "data_offsets": [0, 6553600]
  },
  "model.layers.1.mlp.experts.0.down_proj.weight.scale": {
    "dtype": "BF16",
    "shape": [2560, 40],
    "data_offsets": [6553600, 6963200]
  },
  "model.layers.1.mlp.experts.0.gate_proj.weight": {
    "dtype": "U32",
    "shape": [10240, 320],
    "data_offsets": [6963200, 20070400]
  },
  "model.layers.1.mlp.experts.0.gate_proj.weight.scale": { "..." : "..." }
}
```

### Grouping Rules

- `model.layers.{L}.mlp.experts.*` tensors are packed into one blob per layer
- `model.layers.{L}.mlp.shared_experts.*` tensors are packed into one blob per layer
- All other tensors remain as individual blobs (a standalone sketch of this rule follows below)
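These rules are implemented by `ExpertGroupPrefix` in `x/create` (see the diff above); here is a minimal standalone demo of the same regexp:

```go
// Illustrative: the expert-group regexp from x/create, applied to the
// example tensor names used in this document.
package main

import (
	"fmt"
	"regexp"
)

var expertGroup = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)

func prefix(name string) string {
	if m := expertGroup.FindStringSubmatch(name); m != nil {
		return m[1]
	}
	return "" // not an expert tensor: stored as an individual blob
}

func main() {
	fmt.Println(prefix("model.layers.1.mlp.experts.0.down_proj.weight"))    // model.layers.1.mlp.experts
	fmt.Println(prefix("model.layers.1.mlp.shared_experts.up_proj.weight")) // model.layers.1.mlp.shared_experts
	fmt.Println(prefix("model.layers.1.mlp.gate.weight"))                   // "" (routing gate)
}
```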
### Manifest Representation

One manifest layer per packed group, using the group prefix as the layer name:

```json
{
  "mediaType": "application/vnd.ollama.image.tensor",
  "digest": "sha256:...",
  "size": 123456789,
  "name": "model.layers.1.mlp.experts"
}
```
## Loading

At load time, `mlx_load_safetensors` opens each blob via mmap for zero-copy access. For combined quantized blobs, the loader extracts the `{name}`, `{name}.scale`, and `{name}.bias` tensors and caches them as `name`, `name + "_scale"`, and `name + "_qbias"` respectively, maintaining compatibility with the weight loading interface.

For packed blobs, if the manifest layer name (group prefix) is not found as a tensor key, the loader parses the blob header to discover all tensor names and loads each individually.
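As a sketch, the cache-key mapping described above amounts to:

```go
// Illustrative: blob tensor keys -> weight-cache keys used at load time.
func cacheKeys(name string) map[string]string {
	return map[string]string{
		name:            name,            // packed (or plain) weight
		name + ".scale": name + "_scale", // per-group scales
		name + ".bias":  name + "_qbias", // zero-point offsets (affine only)
	}
}
```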
@@ -1,11 +1,13 @@
 package manifest

 import (
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
 	"os"
 	"path/filepath"
+	"sort"
 	"strings"

 	"github.com/ollama/ollama/envconfig"
@@ -205,17 +207,12 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
 		}
 	}

-	// Fallback: detect quantization from tensor names if not in config
+	// Fallback: detect quantization from first tensor blob's __metadata__
 	if info.Quantization == "" {
-		for _, layer := range manifest.Manifest.Layers {
-			if strings.HasSuffix(layer.Name, ".weight_scale") {
-				info.Quantization = "Q8"
-				break
-			}
-		}
-		if info.Quantization == "" {
-			info.Quantization = "BF16"
-		}
+		info.Quantization = detectQuantizationFromBlobs(manifest)
 	}
+	if info.Quantization == "" {
+		info.Quantization = "BF16"
+	}

 	// Fallback: estimate parameter count if not in config
@@ -223,9 +220,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
 	var totalSize int64
 	for _, layer := range manifest.Manifest.Layers {
 		if layer.MediaType == "application/vnd.ollama.image.tensor" {
-			if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
-				totalSize += layer.Size
-			}
+			totalSize += layer.Size
 		}
 	}
 	// Assume BF16 (2 bytes/param) as rough estimate
@@ -234,3 +229,79 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {

 	return info, nil
 }
+
+// detectQuantizationFromBlobs reads __metadata__ from the first tensor blob
+// to detect quantization type.
+func detectQuantizationFromBlobs(manifest *ModelManifest) string {
+	for _, layer := range manifest.Manifest.Layers {
+		if layer.MediaType != "application/vnd.ollama.image.tensor" {
+			continue
+		}
+		data, err := readBlobHeader(manifest.BlobPath(layer.Digest))
+		if err != nil {
+			continue
+		}
+		var header map[string]json.RawMessage
+		if json.Unmarshal(data, &header) != nil {
+			continue
+		}
+		if metaRaw, ok := header["__metadata__"]; ok {
+			var meta map[string]string
+			if json.Unmarshal(metaRaw, &meta) == nil {
+				if qt, ok := meta["quant_type"]; ok && qt != "" {
+					return strings.ToUpper(qt)
+				}
+			}
+		}
+		// Only check the first tensor blob
+		break
+	}
+	return ""
+}
+
+// ParseBlobTensorNames reads a safetensors blob and returns all "main" tensor names.
+// Filters out __metadata__, .scale, and .bias entries to return only primary weight tensors.
+func ParseBlobTensorNames(path string) ([]string, error) {
+	data, err := readBlobHeader(path)
+	if err != nil {
+		return nil, err
+	}
+
+	var header map[string]json.RawMessage
+	if err := json.Unmarshal(data, &header); err != nil {
+		return nil, err
+	}
+
+	var names []string
+	for k := range header {
+		if k == "__metadata__" || strings.HasSuffix(k, ".scale") || strings.HasSuffix(k, ".bias") {
+			continue
+		}
+		names = append(names, k)
+	}
+
+	sort.Strings(names)
+	return names, nil
+}
+
+// readBlobHeader reads the JSON header bytes from a safetensors blob file.
+func readBlobHeader(path string) ([]byte, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	var headerSize uint64
+	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
+		return nil, err
+	}
+	if headerSize > 1024*1024 {
+		return nil, fmt.Errorf("header too large: %d", headerSize)
+	}
+	data := make([]byte, headerSize)
+	if _, err := io.ReadFull(f, data); err != nil {
+		return nil, err
+	}
+	return data, nil
+}
@@ -5,6 +5,7 @@ package manifest
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
@@ -18,6 +19,8 @@ type ManifestWeights struct {
|
||||
tensors map[string]ManifestLayer // name -> layer
|
||||
cache map[string]*mlx.Array // name -> loaded array
|
||||
nativeCache []*mlx.SafetensorsFile // keep native handles alive
|
||||
quantType string // quantization type from blob metadata (e.g., "int4", "int8")
|
||||
groupSize int // quantization group size from blob metadata
|
||||
}
|
||||
|
||||
// LoadWeightsFromManifest creates a weight loader from manifest storage.
|
||||
@@ -54,43 +57,129 @@ func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*Manife
 
 // Load loads all tensor blobs using native mmap (zero-copy).
 // Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
-// If dtype is non-zero, tensors are converted to the specified dtype.
+// Combined quantized blobs contain tensors keyed by name, name+".scale", and optional name+".bias"
+// with quantization metadata. Scale and bias are stored in cache as name+"_scale"
+// and name+"_qbias" for compatibility with downstream loading code.
+// Packed blobs (e.g., for expert groups) contain multiple tensors; the manifest name
+// is a group prefix and individual tensors are loaded by their actual names from the blob.
+// If dtype is non-zero, non-quantized tensors are converted to the specified dtype.
 func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
 	// Track native handles to free after batch eval
 	nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
 	arrays := make([]*mlx.Array, 0, len(mw.tensors))
 
+	// Group tensors by digest to avoid loading the same blob multiple times
+	type blobEntry struct {
+		name  string
+		layer ManifestLayer
+	}
+	blobGroups := make(map[string][]blobEntry)
 	for name, layer := range mw.tensors {
-		path := mw.manifest.BlobPath(layer.Digest)
+		blobGroups[layer.Digest] = append(blobGroups[layer.Digest], blobEntry{name, layer})
+	}
+
+	for digest, entries := range blobGroups {
+		path := mw.manifest.BlobPath(digest)
 
 		// Load blob as safetensors (native mmap, zero-copy)
 		sf, err := mlx.LoadSafetensorsNative(path)
 		if err != nil {
 			// Free any handles we've accumulated
 			for _, h := range nativeHandles {
 				h.Free()
 			}
-			return fmt.Errorf("load %s: %w", name, err)
+			return fmt.Errorf("load %s: %w", entries[0].name, err)
 		}
 		nativeHandles = append(nativeHandles, sf)
 
-		// Blob contains single tensor named "data"
-		arr := sf.Get("data")
-		if arr == nil {
-			for _, h := range nativeHandles {
-				h.Free()
+		// Read quantization metadata from blob
+		if qt := sf.GetMetadata("quant_type"); qt != "" && mw.quantType == "" {
+			mw.quantType = qt
+			if gs := sf.GetMetadata("group_size"); gs != "" {
+				mw.groupSize, _ = strconv.Atoi(gs)
 			}
-			return fmt.Errorf("tensor 'data' not found in blob for %s", name)
 		}
 
-		// Convert dtype if needed
-		if dtype != 0 && arr.Dtype() != dtype {
-			arr = mlx.AsType(arr, dtype)
+		for _, entry := range entries {
+			name := entry.name
+
+			// Try to get tensor by stripped name first, then with component prefix.
+			// Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight")
+			// while the tensors map uses stripped names (e.g., "model.layers.0.weight").
+			lookupName := name
+			arr := sf.Get(lookupName)
+			if arr == nil && mw.component != "" {
+				lookupName = mw.component + "/" + name
+				arr = sf.Get(lookupName)
+			}
+			if arr != nil {
+				// Single-tensor blob or tensor found by name
+				if dtype != 0 && arr.Dtype() != dtype {
+					arr = mlx.AsType(arr, dtype)
+				}
+				arr = mlx.Contiguous(arr)
+				mw.cache[name] = arr
+				arrays = append(arrays, arr)
+
+				// Check for scale tensor
+				if scale := sf.Get(lookupName + ".scale"); scale != nil {
+					scale = mlx.Contiguous(scale)
+					mw.cache[name+"_scale"] = scale
+					arrays = append(arrays, scale)
+				}
+
+				// Check for bias tensor
+				if bias := sf.Get(lookupName + ".bias"); bias != nil {
+					bias = mlx.Contiguous(bias)
+					mw.cache[name+"_qbias"] = bias
+					arrays = append(arrays, bias)
+				}
+			} else {
+				// Packed blob: manifest name is a group prefix, not a tensor name.
+				// Load all individual tensors from the blob.
+				tensorNames, err := ParseBlobTensorNames(path)
+				if err != nil {
+					for _, h := range nativeHandles {
+						h.Free()
+					}
+					return fmt.Errorf("parse packed blob for %s: %w", name, err)
+				}
+
+				for _, tensorName := range tensorNames {
+					tArr := sf.Get(tensorName)
+					if tArr == nil {
+						continue
+					}
+
+					if dtype != 0 && tArr.Dtype() != dtype {
+						tArr = mlx.AsType(tArr, dtype)
+					}
+					tArr = mlx.Contiguous(tArr)
+
+					// Strip component prefix from blob-internal names so cache keys
+					// match the stripped names used by LoadModule.
+					cacheName := tensorName
+					if mw.component != "" {
+						cacheName = strings.TrimPrefix(tensorName, mw.component+"/")
+					}
+					mw.cache[cacheName] = tArr
+					arrays = append(arrays, tArr)
+
+					// Check for scale tensor
+					if scale := sf.Get(tensorName + ".scale"); scale != nil {
+						scale = mlx.Contiguous(scale)
+						mw.cache[cacheName+"_scale"] = scale
+						arrays = append(arrays, scale)
+					}
+
+					// Check for bias tensor
+					if bias := sf.Get(tensorName + ".bias"); bias != nil {
+						bias = mlx.Contiguous(bias)
+						mw.cache[cacheName+"_qbias"] = bias
+						arrays = append(arrays, bias)
+					}
+				}
+			}
 		}
-		// Make contiguous copy to ensure independence from mmap
-		arr = mlx.Contiguous(arr)
-		mw.cache[name] = arr
-		arrays = append(arrays, arr)
 	}
 
 	// Batch evaluate all tensors at once (much faster than one at a time)
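The core design choice in the hunk above is grouping by content digest so each blob is mmapped once no matter how many tensor names point at it. A tiny, self-contained sketch of that grouping step; the names and digests are made up:

package main

import "fmt"

func main() {
	// Two tensor names stored in the same blob share a digest; grouping
	// by digest means the blob is loaded once, not once per tensor.
	tensors := map[string]string{ // name -> digest (illustrative)
		"model.layers.0.weight":       "sha256:aaa",
		"model.layers.0.weight.scale": "sha256:aaa",
		"model.layers.1.weight":       "sha256:bbb",
	}
	groups := make(map[string][]string)
	for name, digest := range tensors {
		groups[digest] = append(groups[digest], name)
	}
	for digest, names := range groups {
		fmt.Println(digest, "loaded once for", len(names), "tensor(s):", names)
	}
}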
@@ -117,30 +206,50 @@ func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
 }
 
 // ListTensors returns all tensor names in sorted order.
+// Includes both manifest tensor names and scale/bias entries from combined blobs.
 func (mw *ManifestWeights) ListTensors() []string {
-	names := make([]string, 0, len(mw.tensors))
+	seen := make(map[string]bool, len(mw.tensors)+len(mw.cache))
 	for name := range mw.tensors {
+		seen[name] = true
+	}
+	// Also include cache entries (scale/bias from combined blobs)
+	for name := range mw.cache {
+		seen[name] = true
+	}
+	names := make([]string, 0, len(seen))
+	for name := range seen {
 		names = append(names, name)
 	}
 	sort.Strings(names)
 	return names
 }
 
-// HasTensor checks if a tensor exists.
+// HasTensor checks if a tensor exists in the manifest or cache.
 func (mw *ManifestWeights) HasTensor(name string) bool {
-	_, ok := mw.tensors[name]
-	return ok
+	if _, ok := mw.tensors[name]; ok {
+		return true
+	}
+	// Also check cache for scale/bias entries from combined blobs
+	if _, ok := mw.cache[name]; ok {
+		return true
+	}
+	return false
 }
 
-// Quantization returns the model's quantization type from model_index.json.
+// Quantization returns the model's quantization type.
+// Returns the quant_type from blob metadata (e.g., "int4", "int8", "nvfp4", "mxfp8").
 // Returns empty string if not quantized.
-// Falls back to detecting from tensor names and shapes if not in config.
+// Falls back to model_index.json for image gen models.
 func (mw *ManifestWeights) Quantization() string {
+	if mw.quantType != "" {
+		return strings.ToUpper(mw.quantType)
+	}
+
 	if mw.manifest == nil {
 		return ""
 	}
 
-	// Try to read from model_index.json first
+	// Fallback: read from model_index.json (for image gen models)
 	var index struct {
 		Quantization string `json:"quantization"`
 	}
@@ -148,89 +257,22 @@ func (mw *ManifestWeights) Quantization() string {
 		return index.Quantization
 	}
 
-	// Fallback: detect from tensor names
-	// Check if any tensors have _scale suffix (indicates quantization)
-	hasScales := false
-	hasQBias := false
-	for name := range mw.tensors {
-		if strings.HasSuffix(name, ".weight_scale") {
-			hasScales = true
-		}
-		if strings.HasSuffix(name, ".weight_qbias") {
-			hasQBias = true
-		}
-	}
-
-	if !hasScales {
-		// No scales = not quantized
-		return ""
-	}
-
-	// Has scales but no qbias = NVFP4 (or other non-affine mode)
-	if !hasQBias {
-		return "NVFP4"
-	}
-
-	// Has both scales and qbias = affine mode
-	// Need to determine FP4 vs FP8 from tensor shapes
-	// FP4: weight last dim is 1/8 of scales last dim * group_size
-	// FP8: weight last dim is 1/4 of scales last dim * group_size
-	//
-	// For affine mode with group_size=32:
-	// - FP4 (4 bits): 8 elements packed per uint32, so weight_dim = orig_dim / 8
-	// - FP8 (8 bits): 4 elements packed per uint32, so weight_dim = orig_dim / 4
-	// scales_dim = orig_dim / group_size
-	// So: weight_dim / scales_dim = group_size / pack_factor
-	// FP4: ratio = 32/8 = 4
-	// FP8: ratio = 32/4 = 8
-
-	// Find a weight/scale pair to check the ratio
-	for name := range mw.tensors {
-		if !strings.HasSuffix(name, ".weight") || strings.Contains(name, "_scale") || strings.Contains(name, "_qbias") {
-			continue
-		}
-		scaleName := name + "_scale"
-		if _, ok := mw.tensors[scaleName]; !ok {
-			continue
-		}
-
-		// Load both tensors to check shapes
-		weightLayer := mw.tensors[name]
-		scaleLayer := mw.tensors[scaleName]
-
-		// Get shapes from manifest layer metadata if available
-		// For now, default to FP4 since it's more common
-		// The actual shape check would require loading the tensor
-
-		// Simple heuristic: check if scale tensor is ~4x smaller than weight
-		// FP4: weight is packed 8 per uint32, scales are 1 per group (32)
-		// So scale size should be ~weight_size * 8 / 32 = weight_size / 4
-		// FP8: weight is packed 4 per uint32, scales are 1 per group (32)
-		// So scale size should be ~weight_size * 4 / 32 = weight_size / 8
-
-		// Rough size heuristic (assuming float16 scales)
-		// Q4: scale_bytes ≈ weight_bytes / 4 * 2 / 4 = weight_bytes / 8
-		// Q8: scale_bytes ≈ weight_bytes / 8 * 2 / 4 = weight_bytes / 16
-		ratio := float64(weightLayer.Size) / float64(scaleLayer.Size)
-		if ratio < 12 {
-			// Closer to 8 = Q4
-			return "Q4"
-		}
-		// Closer to 16 = Q8
-		return "Q8"
-	}
-
-	// Default to Q4 for affine mode (most common)
-	return "Q4"
+	return ""
 }
 
-// GroupSize returns the quantization group size from model_index.json.
+// GroupSize returns the quantization group size.
+// Returns the group_size from blob metadata.
 // Returns 0 if not specified (caller should use default based on quantization type).
 func (mw *ManifestWeights) GroupSize() int {
+	if mw.groupSize > 0 {
+		return mw.groupSize
+	}
+
 	if mw.manifest == nil {
 		return 0
 	}
 
+	// Fallback: read from model_index.json (for image gen models)
 	var index struct {
 		GroupSize int `json:"group_size"`
 	}
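The byte-ratio argument in the deleted heuristic is easy to restate; this is a restatement of the removed comments, not new analysis. For a weight of $n$ elements packed $p$ per uint32 word, with float16 scales and group size $g$:

$$\frac{\text{weight bytes}}{\text{scale bytes}} = \frac{(n/p)\cdot 4}{(n/g)\cdot 2} = \frac{2g}{p}$$

With $g = 32$, FP4 packing ($p = 8$) gives a ratio of 8 and FP8 packing ($p = 4$) gives 16, which is why the deleted code split the two cases at a threshold of 12. Reading quant_type directly from blob metadata makes this shape guessing unnecessary.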
@@ -1544,6 +1544,18 @@ func (s *SafetensorsFile) Count() int {
 	return 0
 }
 
+// GetMetadata retrieves a metadata value by key from the safetensors file
+func (s *SafetensorsFile) GetMetadata(key string) string {
+	cKey := C.CString(key)
+	defer C.free(unsafe.Pointer(cKey))
+
+	var cValue *C.char
+	if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
+		return ""
+	}
+	return C.GoString(cValue)
+}
+
 // Free releases the safetensors file
 func (s *SafetensorsFile) Free() {
 	C.mlx_map_string_to_array_free(s.arrays)
@@ -1578,6 +1590,41 @@ func SaveSafetensors(path string, arrays map[string]*Array) error {
 	return nil
 }
 
+// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata key/value pairs.
+// This is like SaveSafetensors but inserts metadata into the __metadata__ section.
+func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
+	cPath := C.CString(path)
+	defer C.free(unsafe.Pointer(cPath))
+
+	// Create the array map
+	cArrays := C.mlx_map_string_to_array_new()
+	defer C.mlx_map_string_to_array_free(cArrays)
+
+	for name, arr := range arrays {
+		cName := C.CString(name)
+		C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
+		C.free(unsafe.Pointer(cName))
+	}
+
+	// Create metadata map
+	cMeta := C.mlx_map_string_to_string_new()
+	defer C.mlx_map_string_to_string_free(cMeta)
+
+	for key, value := range metadata {
+		cKey := C.CString(key)
+		cValue := C.CString(value)
+		C.mlx_map_string_to_string_insert(cMeta, cKey, cValue)
+		C.free(unsafe.Pointer(cKey))
+		C.free(unsafe.Pointer(cValue))
+	}
+
+	// Save
+	if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
+		return fmt.Errorf("failed to save safetensors: %s", path)
+	}
+	return nil
+}
+
 // ============ NPY Loading ============
 
 // LoadNpy loads a numpy array from an npy file
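A minimal round-trip sketch of the two additions above, assuming the surrounding mlx bindings build; the file name, tensor names, and keys are illustrative:

// Write a quantized blob tagged with its quantization parameters,
// then read the tag back through the native loader.
arrays := map[string]*mlx.Array{"w": w, "w.scale": scale} // w, scale prepared elsewhere
meta := map[string]string{"quant_type": "int4", "group_size": "32"}
if err := mlx.SaveSafetensorsWithMetadata("blob.safetensors", arrays, meta); err != nil {
	log.Fatal(err)
}

sf, err := mlx.LoadSafetensorsNative("blob.safetensors")
if err != nil {
	log.Fatal(err)
}
defer sf.Free()
fmt.Println(sf.GetMetadata("quant_type")) // "int4"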
@@ -41,13 +41,11 @@ func (td *TensorData) Reader() io.Reader {
 	return td.reader
 }
 
-// SafetensorsReader returns a reader that outputs the tensor wrapped in
-// minimal safetensors format. This allows using mlx_load_safetensors on
-// individual tensor blobs for native zero-copy loading.
-func (td *TensorData) SafetensorsReader() io.Reader {
-	// Build minimal safetensors header with tensor named "data"
-	header := map[string]tensorInfo{
-		"data": {
+// safetensorsHeader builds the JSON header for a minimal safetensors blob
+// containing a single tensor keyed by its name.
+func (td *TensorData) safetensorsHeader() []byte {
+	header := map[string]any{
+		td.Name: tensorInfo{
 			Dtype:       td.Dtype,
 			Shape:       td.Shape,
 			DataOffsets: [2]int{0, int(td.Size)},
@@ -58,6 +56,15 @@ func (td *TensorData) SafetensorsReader() io.Reader {
 	// Pad header to 8-byte alignment
 	padding := (8 - len(headerJSON)%8) % 8
 	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
+	return headerJSON
+}
+
+// SafetensorsReader returns a reader that outputs the tensor wrapped in
+// minimal safetensors format. This allows using mlx_load_safetensors on
+// individual tensor blobs for native zero-copy loading.
+// The tensor is keyed by its name in the safetensors header.
+func (td *TensorData) SafetensorsReader() io.Reader {
+	headerJSON := td.safetensorsHeader()
 
 	// Build header with size prefix
 	headerBuf := new(bytes.Buffer)
@@ -71,16 +78,77 @@ func (td *TensorData) SafetensorsReader() io.Reader {
 
 // SafetensorsSize returns the total size of the safetensors-wrapped tensor.
 func (td *TensorData) SafetensorsSize() int64 {
-	header := map[string]tensorInfo{
-		"data": {
+	headerJSON := td.safetensorsHeader()
+	return 8 + int64(len(headerJSON)) + td.Size
+}
+
+// NewTensorDataFromBytes creates a TensorData from raw tensor bytes.
+// This is useful for constructing packed blobs from already-extracted data.
+func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *TensorData {
+	return &TensorData{
+		Name:   name,
+		Dtype:  dtype,
+		Shape:  shape,
+		Size:   int64(len(rawData)),
+		reader: io.NewSectionReader(bytes.NewReader(rawData), 0, int64(len(rawData))),
+	}
+}
+
+// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts
+// the raw tensor data bytes (stripping the header).
+func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
+	// Read header size (8 bytes, little endian)
+	var headerSize uint64
+	if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
+		return nil, fmt.Errorf("failed to read header size: %w", err)
+	}
+
+	// Skip header
+	if _, err := io.CopyN(io.Discard, r, int64(headerSize)); err != nil {
+		return nil, fmt.Errorf("failed to skip header: %w", err)
+	}
+
+	// Read remaining bytes (the raw tensor data)
+	return io.ReadAll(r)
+}
+
+// BuildPackedSafetensorsReader builds a streaming io.Reader that outputs a valid
+// safetensors file containing multiple tensors. Used for packing expert tensors
+// into a single blob without loading all data into memory.
+// Each TensorData must have been obtained from GetTensor.
+func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
+	// Build the header with sequential data offsets
+	header := make(map[string]tensorInfo, len(tensors))
+	var offset int
+	for _, td := range tensors {
+		header[td.Name] = tensorInfo{
 			Dtype:       td.Dtype,
 			Shape:       td.Shape,
-			DataOffsets: [2]int{0, int(td.Size)},
-		},
+			DataOffsets: [2]int{offset, offset + int(td.Size)},
+		}
+		offset += int(td.Size)
 	}
 
 	headerJSON, _ := json.Marshal(header)
 
 	// Pad header to 8-byte alignment
 	padding := (8 - len(headerJSON)%8) % 8
-	return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
+	headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
+
+	// Build header with size prefix
+	headerBuf := new(bytes.Buffer)
+	binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
+	headerBuf.Write(headerJSON)
+
+	// Build multi-reader: header + all tensor data readers
+	readers := make([]io.Reader, 0, 1+len(tensors))
+	readers = append(readers, headerBuf)
+	for _, td := range tensors {
+		td.reader.Seek(0, io.SeekStart)
+		readers = append(readers, td.reader)
+	}
+
+	return io.MultiReader(readers...)
+}
 
 // OpenForExtraction opens a safetensors file for tensor extraction.
 
@@ -17,7 +17,7 @@ type WeightSource interface {
 	GetTensor(name string) (*mlx.Array, error)
 	ListTensors() []string
 	HasTensor(name string) bool
-	Quantization() string // Returns "NVFP4", "Q4", "Q8", or ""
+	Quantization() string // Returns "NVFP4", "INT4", "INT8", or ""
 	GroupSize() int       // Returns quantization group size, or 0 if not specified
 }
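The sequential-offset layout that BuildPackedSafetensorsReader emits is simple to verify by hand. A self-contained sketch; tensor names and sizes are made up, and offsets are relative to the end of the padded header as in the safetensors spec:

package main

import "fmt"

func main() {
	names := []string{"experts.0.weight", "experts.1.weight"}
	sizes := map[string]int{"experts.0.weight": 64, "experts.1.weight": 128}

	offset := 0
	for _, name := range names {
		fmt.Printf("%s data_offsets=[%d,%d]\n", name, offset, offset+sizes[name])
		offset += sizes[name]
	}
	// Output:
	// experts.0.weight data_offsets=[0,64]
	// experts.1.weight data_offsets=[64,192]
}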
96
x/mlxrunner/cache.go
Normal file
@@ -0,0 +1,96 @@
//go:build mlx

package mlxrunner

import (
	"log/slog"

	"github.com/ollama/ollama/x/mlxrunner/cache"
)

type CacheEntry struct {
	Caches  []cache.Cache
	Count   int
	Entries map[int32]*CacheEntry
}

func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
	current := &CacheEntry{Entries: s.CacheEntries}
	index, cacheIndex := 0, -1
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			break
		}

		current = current.Entries[token]
		if len(current.Caches) > 0 {
			cacheIndex = index
		}

		index += 1
	}

	if cacheIndex == len(tokens)-1 {
		slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", 0)
		return current.Caches, []int32{}
	} else if cacheIndex > 1 {
		slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
		return current.Caches, tokens[cacheIndex+1:]
	} else if index > 0 && cacheIndex < 0 {
		type stackItem struct {
			entry  *CacheEntry
			tokens []int32
		}

		var best, item stackItem
		stack := []stackItem{{entry: current, tokens: []int32{}}}
		for len(stack) > 0 {
			item, stack = stack[len(stack)-1], stack[:len(stack)-1]
			if len(item.entry.Caches) > 0 {
				if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
					best = item
				}
			} else {
				for token, entry := range item.entry.Entries {
					stack = append(stack, stackItem{
						entry:  entry,
						tokens: append(item.tokens, token),
					})
				}
			}
		}

		prefix := min(len(tokens)-1, index)
		caches := make([]cache.Cache, len(best.entry.Caches))
		trim := len(best.tokens) + 1
		for i := range caches {
			caches[i] = best.entry.Caches[i].Clone()
			caches[i].Trim(trim)
		}

		slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
		return caches, tokens[prefix:]
	}

	slog.Info("Cache miss", "left", len(tokens))
	return nil, tokens
}

func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
	current := &CacheEntry{Entries: s.CacheEntries}
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			current.Entries[token] = &CacheEntry{
				Entries: make(map[int32]*CacheEntry),
			}
		}

		current = current.Entries[token]
	}

	if len(current.Caches) > 0 {
		current.Count += 1
	} else {
		current.Caches = caches
	}
}
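The token-trie lookup above boils down to walking the trie as far as the new prompt matches and prefilling only the unmatched suffix. A minimal, pure-Go sketch of that idea, independent of the trie bookkeeping; tokens are made up:

package main

import "fmt"

func main() {
	cached := []int32{1, 2, 3, 4}       // tokens whose KV state is stored
	prompt := []int32{1, 2, 3, 4, 5, 6} // new request

	shared := 0
	for shared < len(cached) && shared < len(prompt) && cached[shared] == prompt[shared] {
		shared++
	}
	fmt.Println("reuse KV for", shared, "tokens; prefill only", prompt[shared:])
}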
196
x/mlxrunner/cache/cache.go
vendored
Normal file
@@ -0,0 +1,196 @@
package cache

import (
	"log/slog"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

type Cache interface {
	Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
	State() (keys, values *mlx.Array)
	Trim(int) int
	Clone() Cache
	Offset() int
	Len() int
}

type KVCache struct {
	keys, values *mlx.Array
	offset       int
	step         int
}

func NewKVCache() *KVCache {
	return &KVCache{step: 256}
}

func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if needed
	if c.keys == nil || (prev+L) > c.keys.Dim(2) {
		steps := (c.step + L - 1) / c.step
		newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)

		if c.keys != nil {
			if prev%c.step != 0 {
				c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
				c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
			}
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
	}

	c.offset += L
	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))

	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}

func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset == c.keys.Dim(2) {
		return c.keys, c.values
	}
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}

func (c *KVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	return n
}

func (c *KVCache) Clone() Cache {
	return &KVCache{
		keys:   c.keys.Clone(),
		values: c.values.Clone(),
		offset: c.offset,
		step:   c.step,
	}
}

func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int    { return c.offset }

// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
	maxSize int
	idx     int

	*KVCache
}

func NewRotatingKVCache(maxSize int) *RotatingKVCache {
	return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}

func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	if keys.Dim(2) > 1 {
		return c.concat(keys, values)
	}
	return c.update(keys, values)
}

func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
	slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	if c.keys == nil {
		c.keys, c.values = keys, values
	} else {
		if c.idx < c.keys.Dim(2) {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
		}

		// Trim to max_size to maintain sliding window
		if trim := c.idx - c.maxSize + 1; trim > 0 {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		}

		c.keys.Set(c.keys.Concatenate(2, keys))
		c.values.Set(c.values.Concatenate(2, values))
		c.idx = c.keys.Dim(2)
	}

	c.offset += keys.Dim(2)
	c.idx = c.keys.Dim(2)
	return c.keys, c.values
}

func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if not yet at max
	if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
		newSize := min(c.step, c.maxSize-prev)
		newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
		if c.keys != nil {
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
		c.idx = prev
	}

	// Trim to max_size to maintain sliding window
	if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
		c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
		c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		c.idx = c.maxSize
	}

	// Rotate when hitting max
	if c.idx >= c.maxSize {
		c.idx = 0
	}

	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))

	c.offset += L
	c.idx += L

	validLen := min(c.offset, c.maxSize)
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}

func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset < c.keys.Dim(2) {
		return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
			c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
	}
	return c.keys, c.values
}

func (c *RotatingKVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	c.idx -= n
	return n
}

func (c *RotatingKVCache) Clone() Cache {
	return &RotatingKVCache{
		maxSize: c.maxSize,
		idx:     c.idx,
		KVCache: c.KVCache.Clone().(*KVCache),
	}
}

func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
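The ring-buffer arithmetic behind RotatingKVCache.update is the easy part to lose in the slicing noise: once idx reaches maxSize it wraps to 0 and overwrites the oldest slot. A pure-Go illustration of just that index walk, with nothing from mlx involved and the window size chosen arbitrarily:

package main

import "fmt"

func main() {
	idx, maxSize := 0, 4
	for step := 0; step < 6; step++ {
		if idx >= maxSize {
			idx = 0 // rotate when hitting max
		}
		fmt.Printf("token %d writes KV slot %d\n", step, idx)
		idx++
	}
	// Slots cycle 0,1,2,3,0,1 — memory stays bounded at maxSize entries.
}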
174
x/mlxrunner/client.go
Normal file
@@ -0,0 +1,174 @@
package mlxrunner

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"math"
	"net"
	"net/http"
	"net/url"
	"os/exec"
	"strconv"
	"strings"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
)

type Client struct {
	Port int
	*exec.Cmd
}

func (c *Client) JoinPath(path string) string {
	return (&url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
	}).JoinPath(path).String()
}

func (c *Client) CheckError(w *http.Response) error {
	if w.StatusCode >= 400 {
		return errors.New(w.Status)
	}
	return nil
}

// Close implements llm.LlamaServer.
func (c *Client) Close() error {
	return c.Cmd.Process.Kill()
}

// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(req); err != nil {
		return err
	}

	w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
	if err != nil {
		return err
	}
	defer w.Body.Close()

	if err := c.CheckError(w); err != nil {
		return err
	}

	scanner := bufio.NewScanner(w.Body)
	for scanner.Scan() {
		bts := scanner.Bytes()

		var resp llm.CompletionResponse
		if err := json.Unmarshal(bts, &resp); err != nil {
			return err
		}

		fn(resp)
	}

	return nil
}

func (c *Client) ContextLength() int {
	return math.MaxInt
}

// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
	panic("unimplemented")
}

// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	panic("unimplemented")
}

// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	panic("unimplemented")
}

// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
	return c.Port
}

// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
	panic("unimplemented")
}

// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
	w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()

	return []ml.DeviceID{}, nil
}

// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
	panic("unimplemented")
}

// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
	panic("unimplemented")
}

// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
	w, err := http.Get(c.JoinPath("/v1/status"))
	if err != nil {
		return err
	}
	defer w.Body.Close()

	return nil
}

// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
	w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()

	var tokens []int
	if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
		return nil, err
	}

	return tokens, nil
}

// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
	panic("unimplemented")
}

// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
	panic("unimplemented")
}

// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
	panic("unimplemented")
}

// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
	panic("unimplemented")
}

var _ llm.LlamaServer = (*Client)(nil)
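A hedged usage sketch of the Client above: the port and prompt are arbitrary, error handling is trimmed, and the exact llm.CompletionRequest fields are assumed from their use elsewhere in the repo rather than verified here.

c := &mlxrunner.Client{Port: 11435} // illustrative port
if err := c.Ping(context.Background()); err != nil {
	log.Fatal(err)
}
// Each newline-delimited JSON chunk from /v1/completions invokes the callback.
err := c.Completion(context.Background(), llm.CompletionRequest{Prompt: "hello"}, func(r llm.CompletionResponse) {
	fmt.Print(r.Content)
})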
3
x/mlxrunner/mlx/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
_deps
build
dist
26
x/mlxrunner/mlx/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.5)

project(mlx)

if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()

set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)

set(CMAKE_INSTALL_RPATH "@loader_path")

include(FetchContent)

set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG}
)

FetchContent_MakeAvailable(mlx-c)
21
x/mlxrunner/mlx/act.go
Normal file
@@ -0,0 +1,21 @@
package mlx

// #include "generated.h"
import "C"
import "math"

func GELUApprox(t *Array) *Array {
	return t.Multiply(
		FromValue[float32](0.5),
	).Multiply(
		t.Add(
			t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
		).Multiply(
			FromValue(float32(math.Sqrt(2 / math.Pi))),
		).Tanh().Add(FromValue[float32](1.0)),
	).AsType(t.DType())
}

func SILU(t *Array) *Array {
	return t.Multiply(t.Sigmoid()).AsType(t.DType())
}
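GELUApprox is the standard tanh approximation; written out, the two activations above compute

$$\mathrm{GELU}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right), \qquad \mathrm{SiLU}(x) = x\,\sigma(x),$$

where $\sigma$ is the logistic sigmoid.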
271
x/mlxrunner/mlx/array.go
Normal file
@@ -0,0 +1,271 @@
package mlx

// #include "generated.h"
import "C"

import (
	"encoding/binary"
	"log/slog"
	"reflect"
	"strings"
	"time"
	"unsafe"

	"github.com/ollama/ollama/logutil"
)

type tensorDesc struct {
	name    string
	inputs  []*Array
	numRefs int
}

func (d tensorDesc) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", d.name),
		slog.Int("inputs", len(d.inputs)),
		slog.Int("num_refs", d.numRefs),
	)
}

type Array struct {
	ctx  C.mlx_array
	desc tensorDesc
}

// constructor utilities

func New(name string, inputs ...*Array) *Array {
	t := &Array{
		desc: tensorDesc{
			name:   name,
			inputs: inputs,
		},
	}

	for _, input := range inputs {
		input.desc.numRefs++
	}
	logutil.Trace("New", "t", t)
	return t
}

type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}

func FromValue[T scalarTypes](t T) *Array {
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		panic("unsupported type")
	}
	return tt
}

type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}

func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	if len(shape) == 0 {
		panic("shape must be provided for non-scalar tensors")
	}

	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}

	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("unsupported type")
	}

	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}

	tt := New("")
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}

func (t *Array) Set(other *Array) {
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}

func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}

// misc. utilities

func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}

func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}

func (t *Array) LogValue() slog.Value {
	attrs := []slog.Attr{slog.Any("", t.desc)}
	if t.Valid() {
		attrs = append(attrs,
			slog.Any("dtype", t.DType()),
			slog.Any("shape", t.Dims()),
			slog.Int("num_bytes", t.NumBytes()),
		)
	}
	return slog.GroupValue(attrs...)
}

// shape utilities

func (t Array) Size() int {
	return int(C.mlx_array_size(t.ctx))
}

func (t Array) NumBytes() int {
	return int(C.mlx_array_nbytes(t.ctx))
}

func (t Array) NumDims() int {
	return int(C.mlx_array_ndim(t.ctx))
}

func (t Array) Dims() []int {
	dims := make([]int, t.NumDims())
	for i := range dims {
		dims[i] = t.Dim(i)
	}

	return dims
}

func (t Array) Dim(dim int) int {
	return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}

func (t Array) DType() DType {
	return DType(C.mlx_array_dtype(t.ctx))
}

// data utilities

func (t Array) Int() int {
	var item C.int64_t
	C.mlx_array_item_int64(&item, t.ctx)
	return int(item)
}

func (t Array) Float() float64 {
	var item C.double
	C.mlx_array_item_float64(&item, t.ctx)
	return float64(item)
}

func (t Array) Ints() []int {
	ints := make([]int, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
		ints[i] = int(f)
	}
	return ints
}

func (t Array) Floats() []float32 {
	floats := make([]float32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
		floats[i] = float32(f)
	}
	return floats
}

func (t Array) Save(name string) error {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))
	C.mlx_save(cName, t.ctx)
	return nil
}

func Free(s ...*Array) (n int) {
	now := time.Now()
	defer func() {
		if n > 0 {
			logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
		}
	}()

	free := make([]*Array, 0, 8192)
	fn := func(t *Array) {
		if t.Valid() {
			free = append(free, t.desc.inputs...)
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				logutil.Trace("Free", "t", t)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				t.ctx.ctx = nil
			}
		}
	}

	for _, t := range s {
		fn(t)
	}

	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}

	return n
}
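The Free walk above is easiest to see on a tiny expression graph. A hedged usage sketch, assuming the mlx shared library is loadable and using Multiply as it appears elsewhere in this change:

a := mlx.FromValue(float32(2))
b := mlx.FromValue(float32(3))
c := a.Multiply(b) // New(...) records a and b as c's inputs and bumps their refcounts

// Freeing the root walks desc.inputs iteratively and releases a and b too,
// once each reference count reaches zero.
mlx.Free(c)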
43
x/mlxrunner/mlx/array_test.go
Normal file
@@ -0,0 +1,43 @@
package mlx

import "testing"

func TestFromValue(t *testing.T) {
	for got, want := range map[*Array]DType{
		FromValue(true):              DTypeBool,
		FromValue(false):             DTypeBool,
		FromValue(int(7)):            DTypeInt32,
		FromValue(float32(3.14)):     DTypeFloat32,
		FromValue(float64(2.71)):     DTypeFloat64,
		FromValue(complex64(1 + 2i)): DTypeComplex64,
	} {
		t.Run(want.String(), func(t *testing.T) {
			if got.DType() != want {
				t.Errorf("want %v, got %v", want, got)
			}
		})
	}
}

func TestFromValues(t *testing.T) {
	for got, want := range map[*Array]DType{
		FromValues([]bool{true, false, true}, 3):           DTypeBool,
		FromValues([]uint8{1, 2, 3}, 3):                    DTypeUint8,
		FromValues([]uint16{1, 2, 3}, 3):                   DTypeUint16,
		FromValues([]uint32{1, 2, 3}, 3):                   DTypeUint32,
		FromValues([]uint64{1, 2, 3}, 3):                   DTypeUint64,
		FromValues([]int8{-1, -2, -3}, 3):                  DTypeInt8,
		FromValues([]int16{-1, -2, -3}, 3):                 DTypeInt16,
		FromValues([]int32{-1, -2, -3}, 3):                 DTypeInt32,
		FromValues([]int64{-1, -2, -3}, 3):                 DTypeInt64,
		FromValues([]float32{3.14, 2.71, 1.61}, 3):         DTypeFloat32,
		FromValues([]float64{3.14, 2.71, 1.61}, 3):         DTypeFloat64,
		FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
	} {
		t.Run(want.String(), func(t *testing.T) {
			if got.DType() != want {
				t.Errorf("want %v, got %v", want, got)
			}
		})
	}
}
94
x/mlxrunner/mlx/dtype.go
Normal file
@@ -0,0 +1,94 @@
package mlx

// #include "generated.h"
import "C"

type DType int

func (t DType) String() string {
	switch t {
	case DTypeBool:
		return "BOOL"
	case DTypeUint8:
		return "U8"
	case DTypeUint16:
		return "U16"
	case DTypeUint32:
		return "U32"
	case DTypeUint64:
		return "U64"
	case DTypeInt8:
		return "I8"
	case DTypeInt16:
		return "I16"
	case DTypeInt32:
		return "I32"
	case DTypeInt64:
		return "I64"
	case DTypeFloat16:
		return "F16"
	case DTypeFloat32:
		return "F32"
	case DTypeFloat64:
		return "F64"
	case DTypeBFloat16:
		return "BF16"
	case DTypeComplex64:
		return "C64"
	default:
		return "Unknown"
	}
}

func (t *DType) UnmarshalJSON(b []byte) error {
	switch string(b) {
	case `"BOOL"`:
		*t = DTypeBool
	case `"U8"`:
		*t = DTypeUint8
	case `"U16"`:
		*t = DTypeUint16
	case `"U32"`:
		*t = DTypeUint32
	case `"U64"`:
		*t = DTypeUint64
	case `"I8"`:
		*t = DTypeInt8
	case `"I16"`:
		*t = DTypeInt16
	case `"I32"`:
		*t = DTypeInt32
	case `"I64"`:
		*t = DTypeInt64
	case `"F16"`:
		*t = DTypeFloat16
	case `"F64"`:
		*t = DTypeFloat64
	case `"F32"`:
		*t = DTypeFloat32
	case `"BF16"`:
		*t = DTypeBFloat16
	case `"C64"`:
		*t = DTypeComplex64
	default:
		return nil
	}
	return nil
}

const (
	DTypeBool      DType = C.MLX_BOOL
	DTypeUint8     DType = C.MLX_UINT8
	DTypeUint16    DType = C.MLX_UINT16
	DTypeUint32    DType = C.MLX_UINT32
	DTypeUint64    DType = C.MLX_UINT64
	DTypeInt8      DType = C.MLX_INT8
	DTypeInt16     DType = C.MLX_INT16
	DTypeInt32     DType = C.MLX_INT32
	DTypeInt64     DType = C.MLX_INT64
	DTypeFloat16   DType = C.MLX_FLOAT16
	DTypeFloat32   DType = C.MLX_FLOAT32
	DTypeFloat64   DType = C.MLX_FLOAT64
	DTypeBFloat16  DType = C.MLX_BFLOAT16
	DTypeComplex64 DType = C.MLX_COMPLEX64
)
34
x/mlxrunner/mlx/dynamic.c
Normal file
@@ -0,0 +1,34 @@
#include "dynamic.h"

#include <stdio.h>

#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif

static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
  handle->ctx = (void*) DLOPEN(path);
  CHECK(handle->ctx != NULL);
  return 0;
}

int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
  return mlx_dynamic_open(handle, path);
}

void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
  if (handle->ctx) {
    DLCLOSE(handle->ctx);
    handle->ctx = NULL;
  }
}
63
x/mlxrunner/mlx/dynamic.go
Normal file
@@ -0,0 +1,63 @@
package mlx

// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"

import (
	"io/fs"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"unsafe"
)

func init() {
	switch runtime.GOOS {
	case "darwin":

	case "windows":
	default:
		return
	}

	paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
	if !ok {
		slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
		return
	}

	for _, path := range filepath.SplitList(paths) {
		matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
		if err != nil {
			panic(err)
		}

		for _, match := range matches {
			path := filepath.Join(path, match)
			slog.Info("Loading MLX dynamic library", "path", path)

			cPath := C.CString(path)
			defer C.free(unsafe.Pointer(cPath))

			var handle C.mlx_dynamic_handle
			if C.mlx_dynamic_load(&handle, cPath) != 0 {
				slog.Error("Failed to load MLX dynamic library", "path", path)
				continue
			}

			if C.mlx_dynamic_load_symbols(handle) != 0 {
				slog.Error("Failed to load MLX dynamic library symbols", "path", path)
				C.mlx_dynamic_unload(&handle)
				continue
			}

			slog.Info("Loaded MLX dynamic library", "path", path)
			return
		}
	}

	panic("Failed to load any MLX dynamic library")
}
27
x/mlxrunner/mlx/dynamic.h
Normal file
@@ -0,0 +1,27 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H

#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif

#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)

typedef struct {
  void* ctx;
} mlx_dynamic_handle;

int mlx_dynamic_load(
  mlx_dynamic_handle* handle,
  const char *path);

void mlx_dynamic_unload(
  mlx_dynamic_handle* handle);

#endif // MLX_DYNAMIC_H
72
x/mlxrunner/mlx/fast.go
Normal file
@@ -0,0 +1,72 @@
package mlx

// #include "generated.h"
import "C"

import (
	"unsafe"
)

func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
	if mask == nil {
		mask = New("")
	}

	sinks := New("")

	mode := "causal"
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))

	out := New("FAST_SDPA", query, key, value, mask, sinks)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}

type LayerNorm struct {
	Weight Array `weight:"weight"`
	Bias   Array `weight:"bias"`
}

func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

type RMSNorm struct {
	Weight Array `weight:"weight"`
}

func (r RMSNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

type RoPE struct {
	Dims        int
	Traditional bool
	Base        float32 `json:"rope_theta"`
	Scale       float32
}

func (r RoPE) Forward(t *Array, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", t, freqs)
	C.mlx_fast_rope(
		&out.ctx,
		t.ctx,
		C.int(r.Dims),
		C._Bool(r.Traditional),
		C.mlx_optional_float{
			value:     C.float(r.Base),
			has_value: C._Bool(func() bool { return r.Base != 0 }()),
		},
		C.float(r.Scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}
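For reference, the two fused norms wrapped above compute the standard definitions, where $\varepsilon$ is the eps argument, $d$ is the feature dimension, and $\mu$, $\sigma^2$ are the per-feature-vector mean and variance:

$$\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}\odot w + b, \qquad \mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_i x_i^2 + \varepsilon}}\odot w.$$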
2724
x/mlxrunner/mlx/generated.c
Normal file
File diff suppressed because it is too large
7135
x/mlxrunner/mlx/generated.h
Normal file
File diff suppressed because it is too large
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
@@ -0,0 +1,17 @@
// This code is auto-generated; DO NOT EDIT.

#include "generated.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}

int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
  CHECK_LOAD(handle, {{ .Name }});
{{- end }}
  return 0;
}
22
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
@@ -0,0 +1,22 @@
// This code is auto-generated; DO NOT EDIT.

#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H

#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}

int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
  return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}

#endif // MLX_GENERATED_H
135
x/mlxrunner/mlx/generator/main.go
Normal file
@@ -0,0 +1,135 @@
package main

import (
	"embed"
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"text/template"

	tree_sitter "github.com/tree-sitter/go-tree-sitter"
	tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)

//go:embed *.gotmpl
var fsys embed.FS

type Function struct {
	Type,
	Name,
	Parameters,
	Args string
}

func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
	var fn Function
	fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
	if params := node.ChildByFieldName("parameters"); params != nil {
		fn.Parameters = params.Utf8Text(source)
		fn.Args = ParseParameters(params, tc, source)
	}

	var types []string
	for node.Parent() != nil && node.Parent().Kind() != "declaration" {
		if node.Parent().Kind() == "pointer_declarator" {
			types = append(types, "*")
		}
		node = node.Parent()
	}

	for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
		types = append(types, sibling.Utf8Text(source))
	}

	slices.Reverse(types)
	fn.Type = strings.Join(types, " ")
	return fn
}

func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
	var s []string
	for _, child := range node.Children(tc) {
		if child.IsNamed() {
			child := child.ChildByFieldName("declarator")
			for child != nil && child.Kind() != "identifier" {
				if child.Kind() == "parenthesized_declarator" {
					child = child.Child(1)
				} else {
					child = child.ChildByFieldName("declarator")
				}
			}

			if child != nil {
				s = append(s, child.Utf8Text(source))
			}
		}
	}
	return strings.Join(s, ", ")
}

func main() {
	var output string
	flag.StringVar(&output, "output", ".", "Output directory for generated files")
	flag.Parse()

	parser := tree_sitter.NewParser()
	defer parser.Close()

	language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
	parser.SetLanguage(language)

	query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
	defer query.Close()

	qc := tree_sitter.NewQueryCursor()
	defer qc.Close()

	var funs []Function
	for _, arg := range flag.Args() {
		bts, err := os.ReadFile(arg)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
			continue
		}

		tree := parser.Parse(bts, nil)
		defer tree.Close()

		tc := tree.Walk()
		defer tc.Close()

		matches := qc.Matches(query, tree.RootNode(), bts)
		for match := matches.Next(); match != nil; match = matches.Next() {
			for _, capture := range match.Captures {
				funs = append(funs, ParseFunction(&capture.Node, tc, bts))
			}
		}
	}

	tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
		return
	}

	for _, tmpl := range tmpl.Templates() {
		name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))

		fmt.Println("Generating", name)
		f, err := os.Create(name)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
			continue
		}
		defer f.Close()

		if err := tmpl.Execute(f, map[string]any{
			"Functions": funs,
		}); err != nil {
			fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
		}
	}
}
43 x/mlxrunner/mlx/io.go Normal file
@@ -0,0 +1,43 @@
package mlx

// #include "generated.h"
import "C"

import (
    "iter"
    "unsafe"
)

func Load(path string) iter.Seq2[string, *Array] {
    return func(yield func(string, *Array) bool) {
        string2array := C.mlx_map_string_to_array_new()
        defer C.mlx_map_string_to_array_free(string2array)

        string2string := C.mlx_map_string_to_string_new()
        defer C.mlx_map_string_to_string_free(string2string)

        cPath := C.CString(path)
        defer C.free(unsafe.Pointer(cPath))

        cpu := C.mlx_default_cpu_stream_new()
        defer C.mlx_stream_free(cpu)

        C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)

        it := C.mlx_map_string_to_array_iterator_new(string2array)
        defer C.mlx_map_string_to_array_iterator_free(it)

        for {
            var key *C.char
            value := C.mlx_array_new()
            if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
                break
            }

            name := C.GoString(key)
            if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
                break
            }
        }
    }
}
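Load returns a Go iterator (iter.Seq2), so tensors stream out one at a time rather than being collected up front, and the C-side maps are freed when the range ends. A minimal usage sketch; the file path is a hypothetical placeholder:

    weights := make(map[string]*mlx.Array)
    for name, arr := range mlx.Load("model.safetensors") {
        weights[name] = arr // breaking out early also frees the C iterator
    }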
85 x/mlxrunner/mlx/memory.go Normal file
@@ -0,0 +1,85 @@
package mlx

// #include "generated.h"
import "C"

import (
    "fmt"
    "log/slog"
    "strconv"
)

func (b Byte) String() string {
    return strconv.FormatInt(int64(b), 10) + " B"
}

func (b KibiByte) String() string {
    return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}

func (b MebiByte) String() string {
    return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}

func (b GibiByte) String() string {
    return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}

func (b TebiByte) String() string {
    return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}

func PrettyBytes(n int) fmt.Stringer {
    switch {
    case n < 1<<10:
        return Byte(n)
    case n < 1<<(2*10):
        return KibiByte(n)
    case n < 1<<(3*10):
        return MebiByte(n)
    case n < 1<<(4*10):
        return GibiByte(n)
    default:
        return TebiByte(n)
    }
}

func ActiveMemory() int {
    var active C.size_t
    C.mlx_get_active_memory(&active)
    return int(active)
}

func CacheMemory() int {
    var cache C.size_t
    C.mlx_get_cache_memory(&cache)
    return int(cache)
}

func PeakMemory() int {
    var peak C.size_t
    C.mlx_get_peak_memory(&peak)
    return int(peak)
}

type Memory struct{}

func (Memory) LogValue() slog.Value {
    return slog.GroupValue(
        slog.Any("active", PrettyBytes(ActiveMemory())),
        slog.Any("cache", PrettyBytes(CacheMemory())),
        slog.Any("peak", PrettyBytes(PeakMemory())),
    )
}

type (
    Byte     int
    KibiByte int
    MebiByte int
    GibiByte int
    TebiByte int
)

func ClearCache() {
    C.mlx_clear_cache()
}
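PrettyBytes picks the largest binary unit that keeps the value readable, and Memory implements slog.LogValuer so all three gauges land on one structured log line. A sketch of the intended call site; the logged values shown are illustrative only:

    slog.Info("mlx memory", "memory", mlx.Memory{})
    // memory.active=1.23 GiB memory.cache=256.00 MiB memory.peak=1.50 GiB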
38 x/mlxrunner/mlx/mlx.go Normal file
@@ -0,0 +1,38 @@
package mlx

//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"

// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"

func doEval(outputs []*Array, async bool) {
    vector := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(vector)

    for _, output := range outputs {
        if output.Valid() {
            C.mlx_vector_array_append_value(vector, output.ctx)
        }
    }

    if async {
        C.mlx_async_eval(vector)
    } else {
        C.mlx_eval(vector)
    }
}

func AsyncEval(outputs ...*Array) {
    doEval(outputs, true)
}

func Eval(outputs ...*Array) {
    doEval(outputs, false)
}
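Eval and AsyncEval are the only synchronization points in this package; every other call just extends a lazy graph. A sketch of the pattern, assuming a and b are previously constructed arrays:

    c := a.Matmul(b) // no computation happens here
    mlx.AsyncEval(c) // schedule the graph without blocking
    // ... build the next step while c computes ...
    mlx.Eval(c)      // block until c is materialized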
36 x/mlxrunner/mlx/nn.go Normal file
@@ -0,0 +1,36 @@
package mlx

type Linear struct {
    Weight Array `weight:"weight"`
    Bias   Array `weight:"bias"`
}

// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
    w := m.Weight.Transpose(1, 0)
    if m.Bias.Valid() {
        return m.Bias.Addmm(x, w, 1.0, 1.0)
    }

    return x.Matmul(w)
}

func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
    w := m.Weight.Transpose(0, 2, 1)
    // TODO: bias
    return x.GatherMM(w, lhs, rhs, sorted)
}

type Embedding struct {
    Weight Array `weight:"weight"`
}

func (e *Embedding) Forward(indices *Array) *Array {
    return e.Weight.TakeAxis(indices, 0)
}

func (e *Embedding) AsLinear() Linear {
    return Linear{
        Weight: e.Weight,
    }
}
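AsLinear enables weight tying: the same [vocab, hidden] matrix that embeds token ids is reused, transposed by Linear.Forward, to unembed hidden states into logits. A sketch, assuming tok ids and hidden states h are built elsewhere:

    var embed mlx.Embedding // weights loaded elsewhere
    x := embed.Forward(tok)               // [B, L] ids -> [B, L, hidden]
    logits := embed.AsLinear().Forward(h) // [B, L, hidden] -> [B, L, vocab]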
254 x/mlxrunner/mlx/ops.go Normal file
@@ -0,0 +1,254 @@
package mlx

// #include "generated.h"
import "C"

import (
    "unsafe"
)

func (t *Array) Abs() *Array {
    out := New("ABS", t)
    C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Add(other *Array) *Array {
    out := New("ADD", t, other)
    C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
    out := New("ADDMM", t, a, b)
    C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
    return out
}

func (t *Array) Argmax(axis int, keepDims bool) *Array {
    out := New("ARGMAX", t)
    C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
    return out
}

func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
    out := New("ARGPARTITION", t)
    C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) ArgsortAxis(axis int) *Array {
    out := New("ARGSORT_AXIS", t)
    C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) AsType(dtype DType) *Array {
    out := New("AS_TYPE", t)
    C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
    return out
}

func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
    cShape := make([]C.int, len(shape))
    for i, s := range shape {
        cShape[i] = C.int(s)
    }

    cStrides := make([]C.int64_t, len(strides))
    for i, s := range strides {
        cStrides[i] = C.int64_t(s)
    }

    out := New("AS_STRIDED", t)
    C.mlx_as_strided(
        &out.ctx, t.ctx,
        unsafe.SliceData(cShape), C.size_t(len(shape)),
        unsafe.SliceData(cStrides), C.size_t(len(strides)),
        C.size_t(offset),
        DefaultStream().ctx,
    )
    return out
}

func (t *Array) Concatenate(axis int, others ...*Array) *Array {
    vector := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(vector)

    s := append([]*Array{t}, others...)
    for _, other := range s {
        C.mlx_vector_array_append_value(vector, other.ctx)
    }

    out := New("CONCATENATE", s...)
    C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) Divide(other *Array) *Array {
    out := New("DIVIDE", t, other)
    C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) ExpandDims(axis int) *Array {
    out := New("EXPAND_DIMS", t)
    C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) Flatten(startAxis, endAxis int) *Array {
    out := New("FLATTEN", t)
    C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
    return out
}

func (t *Array) FloorDivide(other *Array) *Array {
    out := New("FLOOR_DIVIDE", t, other)
    C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
    if lhs == nil {
        lhs = New("")
    }
    if rhs == nil {
        rhs = New("")
    }
    out := New("GATHER_MM", t, other, lhs, rhs)
    C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
    return out
}

func (t *Array) Logsumexp(keepDims bool) *Array {
    out := New("LOGSUMEXP", t)
    C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
    return out
}

func (t *Array) Matmul(other *Array) *Array {
    out := New("MATMUL", t, other)
    C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Multiply(other *Array) *Array {
    out := New("MULTIPLY", t, other)
    C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Negative() *Array {
    out := New("NEGATIVE", t)
    C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Power(exponent *Array) *Array {
    out := New("POWER", t, exponent)
    C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
    out := New("PUT_ALONG_AXIS", t, indices, values)
    C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) Reshape(axes ...int) *Array {
    cAxes := make([]C.int, len(axes))
    for i := range axes {
        cAxes[i] = C.int(axes[i])
    }

    out := New("RESHAPE", t)
    C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
    return out
}

func (t *Array) Sigmoid() *Array {
    out := New("SIGMOID", t)
    C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Sqrt() *Array {
    out := New("SQRT", t)
    C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Squeeze(axis int) *Array {
    out := New("SQUEEZE", t)
    C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) StackAxis(axis int, others ...*Array) *Array {
    vectorData := make([]C.mlx_array, len(others)+1)
    vectorData[0] = t.ctx
    for i := range others {
        vectorData[i+1] = others[i].ctx
    }

    vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
    defer C.mlx_vector_array_free(vector)

    out := New("STACK_AXIS", append(others, t)...)
    C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) Subtract(other *Array) *Array {
    out := New("SUBTRACT", t, other)
    C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) SumAxis(axis int, keepDims bool) *Array {
    out := New("SUM_AXIS", t)
    C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
    return out
}

func (t *Array) TakeAxis(indices *Array, axis int) *Array {
    out := New("TAKE_AXIS", t, indices)
    C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
    out := New("TAKE_ALONG_AXIS", t, indices)
    C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
    return out
}

func (t *Array) Tanh() *Array {
    out := New("TANH", t)
    C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
    return out
}

func (t *Array) Transpose(axes ...int) *Array {
    cAxes := make([]C.int, len(axes))
    for i, axis := range axes {
        cAxes[i] = C.int(axis)
    }

    out := New("TRANSPOSE", t)
    C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
    return out
}

func Zeros(dtype DType, shape ...int) *Array {
    cAxes := make([]C.int, len(shape))
    for i := range shape {
        cAxes[i] = C.int(shape[i])
    }

    t := New("ZEROS")
    C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
    return t
}
425 x/mlxrunner/mlx/ops_extra.go Normal file
@@ -0,0 +1,425 @@
package mlx

// #include "generated.h"
import "C"

import (
    "reflect"
    "unsafe"
)

// Quantization operations

func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
    cMode := C.CString(mode)
    defer C.free(unsafe.Pointer(cMode))
    optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
    optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
    res := C.mlx_vector_array_new()
    defer C.mlx_vector_array_free(res)
    C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)

    vecSize := int(C.mlx_vector_array_size(res))
    w0 := New("QUANTIZE_W")
    C.mlx_vector_array_get(&w0.ctx, res, 0)
    w1 := New("QUANTIZE_S")
    C.mlx_vector_array_get(&w1.ctx, res, 1)
    if vecSize >= 3 {
        w2 := New("QUANTIZE_B")
        C.mlx_vector_array_get(&w2.ctx, res, 2)
        return w0, w1, w2
    }
    return w0, w1, nil
}

func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
    cMode := C.CString(mode)
    defer C.free(unsafe.Pointer(cMode))
    optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
    optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
    optDtype := C.mlx_optional_dtype{has_value: false}

    inputs := []*Array{w, scales}
    var b C.mlx_array
    if biases != nil {
        b = biases.ctx
        inputs = append(inputs, biases)
    }

    out := New("DEQUANTIZE", inputs...)
    C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
    return out
}

func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
    cMode := C.CString(mode)
    defer C.free(unsafe.Pointer(cMode))
    optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
    optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

    inputs := []*Array{x, w, scales}
    var b C.mlx_array
    if biases != nil {
        b = biases.ctx
        inputs = append(inputs, biases)
    }

    out := New("QUANTIZED_MATMUL", inputs...)
    C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
    return out
}

func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
    cMode := C.CString(mode)
    defer C.free(unsafe.Pointer(cMode))
    optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
    optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}

    inputs := []*Array{x, w, scales}
    var b, lhs, rhs C.mlx_array
    if biases != nil {
        b = biases.ctx
        inputs = append(inputs, biases)
    }
    if lhsIndices != nil {
        lhs = lhsIndices.ctx
        inputs = append(inputs, lhsIndices)
    }
    if rhsIndices != nil {
        rhs = rhsIndices.ctx
        inputs = append(inputs, rhsIndices)
    }

    out := New("GATHER_QMM", inputs...)
    C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
    return out
}

// Missing tensor ops

func Tile(a *Array, reps []int32) *Array {
    cReps := make([]C.int, len(reps))
    for i, r := range reps {
        cReps[i] = C.int(r)
    }
    out := New("TILE", a)
    C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
    return out
}

func Tri(n, m int32, k int) *Array {
    out := New("TRI")
    C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
    return out
}

func Where(condition, a, b *Array) *Array {
    out := New("WHERE", condition, a, b)
    C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
    return out
}

// Convenience wrappers (function-style for the model code)

func Stack(arrays []*Array, axis int) *Array {
    vectorData := make([]C.mlx_array, len(arrays))
    for i := range arrays {
        vectorData[i] = arrays[i].ctx
    }
    vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
    defer C.mlx_vector_array_free(vector)

    out := New("STACK", arrays...)
    C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
    return out
}

func Neg(a *Array) *Array {
    return a.Negative()
}

func Sum(a *Array, axis int, keepDims bool) *Array {
    return a.SumAxis(axis, keepDims)
}

func Argsort(a *Array, axis int) *Array {
    return a.ArgsortAxis(axis)
}

func Take(a *Array, indices *Array, axis int) *Array {
    return a.TakeAxis(indices, axis)
}

func RSqrt(a *Array) *Array {
    out := New("RSQRT", a)
    C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
    return out
}

func Mean(a *Array, axis int, keepDims bool) *Array {
    out := New("MEAN_AXIS", a)
    C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
    return out
}

func Argpartition(a *Array, kth int, axis int) *Array {
    return a.ArgpartitionAxis(kth, axis)
}

func TakeAlongAxis(a, indices *Array, axis int) *Array {
    return a.TakeAlongAxis(indices, axis)
}

// Function-style wrappers matching imagegen API

func Add(a, b *Array) *Array {
    return a.Add(b)
}

func Sub(a, b *Array) *Array {
    return a.Subtract(b)
}

func Mul(a, b *Array) *Array {
    return a.Multiply(b)
}

func Div(a, b *Array) *Array {
    return a.Divide(b)
}

func Matmul(a, b *Array) *Array {
    return a.Matmul(b)
}

func Reshape(a *Array, shape ...int32) *Array {
    axes := make([]int, len(shape))
    for i, s := range shape {
        axes[i] = int(s)
    }
    return a.Reshape(axes...)
}

func Transpose(a *Array, axes ...int) *Array {
    return a.Transpose(axes...)
}

func ExpandDims(a *Array, axis int) *Array {
    return a.ExpandDims(axis)
}

func Squeeze(a *Array, axis int) *Array {
    return a.Squeeze(axis)
}

func Flatten(a *Array) *Array {
    return a.Flatten(0, -1)
}

func Concatenate(arrays []*Array, axis int) *Array {
    if len(arrays) == 0 {
        return nil
    }
    return arrays[0].Concatenate(axis, arrays[1:]...)
}

func SliceStartStop(a *Array, start, stop []int32) *Array {
    n := len(start)
    cStart := make([]C.int, n)
    cStop := make([]C.int, n)
    cStrides := make([]C.int, n)
    for i := 0; i < n; i++ {
        cStart[i] = C.int(start[i])
        cStop[i] = C.int(stop[i])
        cStrides[i] = 1
    }
    out := New("SLICE", a)
    C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
    return out
}

func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
    if lhsIndices == nil {
        lhsIndices = New("")
    }
    if rhsIndices == nil {
        rhsIndices = New("")
    }
    return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}

func SiLU(a *Array) *Array {
    sig := a.Sigmoid()
    return a.Multiply(sig)
}

func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
    freqs := New("")
    out := New("FAST_ROPE", x, freqs)
    C.mlx_fast_rope(
        &out.ctx,
        x.ctx,
        C.int(dims),
        C.bool(traditional),
        C.mlx_optional_float{
            value:     C.float(base),
            has_value: C.bool(base != 0),
        },
        C.float(scale),
        C.int(offset),
        freqs.ctx,
        DefaultStream().ctx,
    )
    return out
}

func Sigmoid(a *Array) *Array {
    return a.Sigmoid()
}

func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
    mask := New("")
    sinks := New("")
    mode := ""
    if causalMask {
        mode = "causal"
    }
    cMode := C.CString(mode)
    defer C.free(unsafe.Pointer(cMode))

    out := New("FAST_SDPA", q, k, v, mask, sinks)
    C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
    return out
}

func RMSNormFn(x, weight *Array, eps float32) *Array {
    out := New("FAST_RMSNORM", x)
    C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
    return out
}

func AddMM(c, a, b *Array, alpha, beta float32) *Array {
    return c.Addmm(a, b, alpha, beta)
}

// Scalar helpers

func AddScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Add(scalar)
}

func MulScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Multiply(scalar)
}

func DivScalar(a *Array, s float32) *Array {
    scalar := FromValue(s)
    return a.Divide(scalar)
}

func FloorDivideScalar(a *Array, s int32) *Array {
    scalar := FromValue(int(s))
    return a.FloorDivide(scalar)
}

// Array constructors

func NewArrayInt32(data []int32, shape []int32) *Array {
    cShape := make([]C.int, len(shape))
    for i, s := range shape {
        cShape[i] = C.int(s)
    }
    out := New("NEW_ARRAY_INT32")
    out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
    return out
}

func NewScalarArray(value float32) *Array {
    out := New("SCALAR")
    out.ctx = C.mlx_array_new_float32(C.float(value))
    return out
}

func ZerosF32(shape []int32) *Array {
    ints := make([]int, len(shape))
    for i, s := range shape {
        ints[i] = int(s)
    }
    return Zeros(DTypeFloat32, ints...)
}

// Utility

func Collect(v any) []*Array {
    var arrays []*Array
    seen := make(map[uintptr]bool)
    collect(reflect.ValueOf(v), &arrays, seen)
    return arrays
}

func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
    if !v.IsValid() {
        return
    }

    if v.Kind() == reflect.Ptr {
        if v.IsNil() {
            return
        }
        ptr := v.Pointer()
        if seen[ptr] {
            return
        }
        seen[ptr] = true

        if arr, ok := v.Interface().(*Array); ok {
            if arr != nil && arr.Valid() {
                *arrays = append(*arrays, arr)
            }
            return
        }
        collect(v.Elem(), arrays, seen)
        return
    }

    switch v.Kind() {
    case reflect.Struct:
        // Check if this struct IS an Array (not a pointer to one).
        // Addr is only legal on addressable values, e.g. not map elements.
        if v.CanAddr() {
            if arr, ok := v.Addr().Interface().(*Array); ok {
                if arr != nil && arr.Valid() {
                    *arrays = append(*arrays, arr)
                }
                return
            }
        }
        for i := 0; i < v.NumField(); i++ {
            field := v.Field(i)
            if field.CanInterface() {
                collect(field, arrays, seen)
            }
        }
    case reflect.Slice:
        for i := 0; i < v.Len(); i++ {
            collect(v.Index(i), arrays, seen)
        }
    case reflect.Map:
        for _, key := range v.MapKeys() {
            collect(v.MapIndex(key), arrays, seen)
        }
    case reflect.Interface:
        if !v.IsNil() {
            collect(v.Elem(), arrays, seen)
        }
    }
}

func EnableCompile() {
    C.mlx_enable_compile()
}

func DisableCompile() {
    C.mlx_disable_compile()
}
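Collect walks an arbitrary model value with reflection and returns every reachable, valid *Array, which lets the runner evaluate or free a whole model without per-architecture bookkeeping. A sketch using a hypothetical container type (the field names are illustrative only):

    type block struct {
        Attn *mlx.Array
        MLP  []*mlx.Array
    }

    m := &block{ /* arrays assigned elsewhere */ }
    mlx.Eval(mlx.Collect(m)...) // force every reachable tensor in one call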
11 x/mlxrunner/mlx/random.go Normal file
@@ -0,0 +1,11 @@
package mlx

// #include "generated.h"
import "C"

func (t *Array) Categorical(axis int) *Array {
    key := New("")
    out := New("", t, key)
    C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
    return out
}
84 x/mlxrunner/mlx/slice.go Normal file
@@ -0,0 +1,84 @@
package mlx

// #include "generated.h"
import "C"

import (
    "cmp"
    "unsafe"
)

type slice struct {
    args []int
}

func Slice(args ...int) slice {
    return slice{args: args}
}

func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
    if len(slices) != len(dims) {
        panic("number of slice arguments must match number of tensor dimensions")
    }

    args := [3][]C.int{
        make([]C.int, len(slices)),
        make([]C.int, len(slices)),
        make([]C.int, len(slices)),
    }

    for i, s := range slices {
        switch len(s.args) {
        case 0:
            // slice[:]
            args[0][i] = C.int(0)
            args[1][i] = C.int(dims[i])
            args[2][i] = C.int(1)
        case 1:
            // slice[i]
            args[0][i] = C.int(s.args[0])
            args[1][i] = C.int(s.args[0] + 1)
            args[2][i] = C.int(1)
        case 2:
            // slice[i:j]
            args[0][i] = C.int(s.args[0])
            args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
            args[2][i] = C.int(1)
        case 3:
            // slice[i:j:k]
            args[0][i] = C.int(s.args[0])
            args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
            args[2][i] = C.int(s.args[2])
        default:
            panic("invalid slice arguments")
        }
    }

    return args[0], args[1], args[2]
}

func (t *Array) Slice(slices ...slice) *Array {
    starts, stops, strides := makeSlices(t.Dims(), slices...)
    out := New("SLICE", t)
    C.mlx_slice(
        &out.ctx, t.ctx,
        unsafe.SliceData(starts), C.size_t(len(starts)),
        unsafe.SliceData(stops), C.size_t(len(stops)),
        unsafe.SliceData(strides), C.size_t(len(strides)),
        DefaultStream().ctx,
    )
    return out
}

func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
    starts, stops, strides := makeSlices(t.Dims(), slices...)
    out := New("SLICE_UPDATE", t, other)
    C.mlx_slice_update(
        &out.ctx, t.ctx, other.ctx,
        unsafe.SliceData(starts), C.size_t(len(starts)),
        unsafe.SliceData(stops), C.size_t(len(stops)),
        unsafe.SliceData(strides), C.size_t(len(strides)),
        DefaultStream().ctx,
    )
    return out
}
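Slice mirrors Python's extended indexing: no arguments is [:], one argument selects a single index (keeping the axis), two form a half-open [i:j) range, and three add a stride; a stop of 0 means "to the end of the dimension" via cmp.Or. A sketch against a hypothetical [4, 8] array t:

    row := t.Slice(mlx.Slice(1), mlx.Slice())        // t[1:2, :]  -> shape [1, 8]
    cols := t.Slice(mlx.Slice(), mlx.Slice(2, 6))    // t[:, 2:6]  -> shape [4, 4]
    even := t.Slice(mlx.Slice(), mlx.Slice(0, 0, 2)) // t[:, ::2]  -> shape [4, 4]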
43 x/mlxrunner/mlx/stream.go Normal file
@@ -0,0 +1,43 @@
package mlx

// #include "generated.h"
import "C"

import (
    "log/slog"
    "sync"
)

type Device struct {
    ctx C.mlx_device
}

func (d Device) LogValue() slog.Value {
    str := C.mlx_string_new()
    defer C.mlx_string_free(str)
    C.mlx_device_tostring(&str, d.ctx)
    return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

var DefaultDevice = sync.OnceValue(func() Device {
    d := C.mlx_device_new()
    C.mlx_get_default_device(&d)
    return Device{d}
})

type Stream struct {
    ctx C.mlx_stream
}

func (s Stream) LogValue() slog.Value {
    str := C.mlx_string_new()
    defer C.mlx_string_free(str)
    C.mlx_stream_tostring(&str, s.ctx)
    return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

var DefaultStream = sync.OnceValue(func() Stream {
    s := C.mlx_stream_new()
    C.mlx_get_default_stream(&s, DefaultDevice().ctx)
    return Stream{s}
})
123 x/mlxrunner/pipeline.go Normal file
@@ -0,0 +1,123 @@
//go:build mlx

package mlxrunner

import (
    "bytes"
    "errors"
    "log/slog"
    "time"
    "unicode/utf8"

    "github.com/ollama/ollama/x/mlxrunner/cache"
    "github.com/ollama/ollama/x/mlxrunner/mlx"
)

func (r *Runner) TextGenerationPipeline(request Request) error {
    if r.Model == nil {
        return errors.New("model not loaded")
    }

    inputs := r.Tokenizer.Encode(request.Prompt, true)

    caches, tokens := r.FindNearestCache(inputs)
    if len(caches) == 0 {
        caches = make([]cache.Cache, r.Model.NumLayers())
        for i := range caches {
            caches[i] = cache.NewKVCache()
        }
    }

    total, processed := len(tokens), 0
    slog.Info("Prompt processing progress", "processed", processed, "total", total)
    for total-processed > 1 {
        n := min(2<<10, total-processed-1)
        temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
        defer mlx.Free(temp)
        mlx.Eval(func() []*mlx.Array {
            s := make([]*mlx.Array, 2*len(caches))
            for i, c := range caches {
                s[2*i], s[2*i+1] = c.State()
            }
            return s
        }()...)
        processed += n
        slog.Info("Prompt processing progress", "processed", processed, "total", total)
        mlx.ClearCache()
    }

    step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
        logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches))
        logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

        logprobs := logits.Subtract(logits.Logsumexp(true))
        return request.Sample(logprobs), logprobs
    }

    sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
    mlx.AsyncEval(sample, logprobs)

    var b bytes.Buffer

    now := time.Now()
    final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
    outputs := make([]int32, 0, request.Options.MaxTokens)
    for i := range request.Options.MaxTokens {
        nextSample, nextLogprobs := step(sample)
        mlx.AsyncEval(nextSample, nextLogprobs)

        if i == 0 {
            slog.Info("Prompt processing progress", "processed", total, "total", total)
            mlx.Eval(sample)
            final.PromptTokensDuration = time.Since(now)
            now = time.Now()
        }

        output := int32(sample.Int())
        outputs = append(outputs, output)

        if r.Tokenizer.IsEOS(output) {
            final.Token = int(output)
            final.DoneReason = 0
            final.CompletionTokens = i
            break
        }

        request.Responses <- Response{
            Text:  r.Decode(output, &b),
            Token: int(output),
        }

        mlx.Free(sample, logprobs)
        if i%256 == 0 {
            mlx.ClearCache()
        }

        sample, logprobs = nextSample, nextLogprobs
    }

    mlx.Free(sample, logprobs)
    final.CompletionTokensDuration = time.Since(now)
    request.Responses <- final
    r.InsertCache(append(inputs, outputs...), caches)
    return nil
}

func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
    token := r.Tokenizer.Decode([]int32{sample})

    if _, err := b.WriteString(token); err != nil {
        slog.Error("Failed to write token to buffer", "error", err)
        return ""
    }

    if text := b.String(); utf8.ValidString(text) {
        b.Reset()
        return text
    } else if b.Len() >= utf8.UTFMax {
        b.Reset()
        return text
    }

    return ""
}
139 x/mlxrunner/runner.go Normal file
@@ -0,0 +1,139 @@
//go:build mlx

package mlxrunner

import (
    "context"
    "encoding/json"
    "fmt"
    "log/slog"
    "net"
    "net/http"
    "time"

    "golang.org/x/sync/errgroup"

    "github.com/ollama/ollama/x/imagegen/manifest"
    "github.com/ollama/ollama/x/imagegen/tokenizer"
    "github.com/ollama/ollama/x/mlxrunner/cache"
    "github.com/ollama/ollama/x/mlxrunner/mlx"
    "github.com/ollama/ollama/x/mlxrunner/sample"
    "github.com/ollama/ollama/x/models/glm4_moe_lite"
)

// TextModel is the interface that model implementations must satisfy.
type TextModel interface {
    Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
    Unembed(x *mlx.Array) *mlx.Array
    NumLayers() int
}

type Request struct {
    TextCompletionsRequest
    Responses chan Response
    Pipeline  func(Request) error

    sample.Sampler
    caches []cache.Cache
}

type TextCompletionsRequest struct {
    Prompt  string `json:"prompt"`
    Options struct {
        Temperature float32 `json:"temperature"`
        TopP        float32 `json:"top_p"`
        MinP        float32 `json:"min_p"`
        TopK        int     `json:"top_k"`
        MaxTokens   int     `json:"max_tokens"`

        // Deprecated: use MaxTokens instead
        NumPredict int `json:"num_predict"`
    } `json:"options"`
}

type Response struct {
    Text       string    `json:"content,omitempty"`
    Token      int       `json:"token,omitempty"`
    Logprobs   []float32 `json:"logprobs,omitempty"`
    Done       bool      `json:"done,omitempty"`
    DoneReason int       `json:"done_reason,omitempty"`

    PromptTokens             int           `json:"prompt_eval_count,omitempty"`
    PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
    CompletionTokens         int           `json:"eval_count,omitempty"`
    CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
    TotalTokens              int           `json:"total_tokens,omitempty"`
}

type Runner struct {
    Model        TextModel
    Tokenizer    *tokenizer.Tokenizer
    Requests     chan Request
    CacheEntries map[int32]*CacheEntry
}

func (r *Runner) Load(modelName string) error {
    modelManifest, err := manifest.LoadManifest(modelName)
    if err != nil {
        return err
    }

    // Read config to detect architecture
    configData, err := modelManifest.ReadConfig("config.json")
    if err != nil {
        return fmt.Errorf("failed to read config.json: %w", err)
    }

    var archConfig struct {
        Architectures []string `json:"architectures"`
    }
    if err := json.Unmarshal(configData, &archConfig); err != nil {
        return fmt.Errorf("failed to parse config.json: %w", err)
    }

    if len(archConfig.Architectures) == 0 {
        return fmt.Errorf("no architectures found in config.json")
    }

    slog.Info("Model architecture", "arch", archConfig.Architectures[0])

    switch archConfig.Architectures[0] {
    case "Glm4MoeLiteForCausalLM", "GLM4MoeLite":
        model, err := glm4_moe_lite.LoadFromManifest(modelManifest)
        if err != nil {
            return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err)
        }
        r.Model = model
        r.Tokenizer = model.Tokenizer()
    default:
        return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0])
    }

    return nil
}

func (r *Runner) Run(host, port string, mux http.Handler) error {
    g, ctx := errgroup.WithContext(context.Background())

    g.Go(func() error {
        for {
            select {
            case <-ctx.Done():
                return nil
            case request := <-r.Requests:
                if err := request.Pipeline(request); err != nil {
                    break
                }

                close(request.Responses)
            }
        }
    })

    g.Go(func() error {
        slog.Info("Starting HTTP server", "host", host, "port", port)
        return http.ListenAndServe(net.JoinHostPort(host, port), mux)
    })

    return g.Wait()
}
75 x/mlxrunner/sample/sample.go Normal file
@@ -0,0 +1,75 @@
package sample

import (
    "math"

    "github.com/ollama/ollama/x/mlxrunner/mlx"
)

type Sampler interface {
    Sample(*mlx.Array) *mlx.Array
}

func New(temp, top_p, min_p float32, top_k int) Sampler {
    if temp == 0 {
        return greedy{}
    }

    var samplers []Sampler
    if top_p > 0 && top_p < 1 {
        samplers = append(samplers, TopP(top_p))
    }

    if min_p != 0 {
        samplers = append(samplers, MinP(min_p))
    }

    if top_k > 0 {
        samplers = append(samplers, TopK(top_k))
    }

    samplers = append(samplers, Temperature(temp))
    return chain(samplers)
}

type greedy struct{}

func (greedy) Sample(logits *mlx.Array) *mlx.Array {
    return logits.Argmax(-1, false)
}

type chain []Sampler

func (c chain) Sample(logits *mlx.Array) *mlx.Array {
    for _, sampler := range c {
        logits = sampler.Sample(logits)
    }
    return logits
}

type Temperature float32

func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
    return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}

type TopP float32

func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
    // TODO: implement
    return logprobs
}

type MinP float32

func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
    // TODO: implement
    return logprobs
}

type TopK int

func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
    mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
    return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
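New composes the filters into a chain that runs left to right: TopP and MinP (both still TODO stubs) and TopK mask logprobs, then Temperature scales them and draws the token via Categorical; temp == 0 bypasses the chain entirely with argmax. A usage sketch, assuming logprobs came from the model:

    s := sample.New(0.7, 0.9, 0, 40) // temperature 0.7, top-p 0.9, top-k 40
    token := s.Sample(logprobs)      // *mlx.Array holding the sampled token id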
176 x/mlxrunner/server.go Normal file
@@ -0,0 +1,176 @@
//go:build mlx

package mlxrunner

import (
    "bytes"
    "cmp"
    "encoding/json"
    "flag"
    "io"
    "log/slog"
    "net/http"
    "os"
    "strconv"
    "time"

    "github.com/ollama/ollama/envconfig"
    "github.com/ollama/ollama/logutil"
    "github.com/ollama/ollama/x/mlxrunner/sample"
)

func Execute(args []string) error {
    slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

    var (
        modelName string
        port      int
    )

    flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
    flagSet.StringVar(&modelName, "model", "", "Model name")
    flagSet.IntVar(&port, "port", 0, "Port to listen on")
    _ = flagSet.Bool("verbose", false, "Enable debug logging")
    flagSet.Parse(args)

    runner := Runner{
        Requests:     make(chan Request),
        CacheEntries: make(map[int32]*CacheEntry),
    }

    if err := runner.Load(modelName); err != nil {
        return err
    }

    mux := http.NewServeMux()
    mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
        if err := json.NewEncoder(w).Encode(map[string]any{
            "status":   0,
            "progress": 100,
        }); err != nil {
            slog.Error("Failed to encode response", "error", err)
            http.Error(w, "Internal Server Error", http.StatusInternalServerError)
            return
        }
    })

    mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
        switch r.Method {
        case "POST":
            fallthrough
        case "GET":
            if err := json.NewEncoder(w).Encode(map[string]any{
                "Success": true,
            }); err != nil {
                slog.Error("Failed to encode response", "error", err)
                http.Error(w, "Internal Server Error", http.StatusInternalServerError)
                return
            }
        case "DELETE":
            // TODO: cleanup model and cache
        }
    })

    mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
        request := Request{Responses: make(chan Response)}

        if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
            slog.Error("Failed to decode request", "error", err)
            http.Error(w, "Bad Request", http.StatusBadRequest)
            return
        }

        request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
        if request.Options.MaxTokens < 1 {
            request.Options.MaxTokens = 16 << 10
        }

        request.Pipeline = runner.TextGenerationPipeline
        request.Sampler = sample.New(
            request.Options.Temperature,
            request.Options.TopP,
            request.Options.MinP,
            request.Options.TopK,
        )

        runner.Requests <- request

        w.Header().Set("Content-Type", "application/jsonl")
        w.WriteHeader(http.StatusOK)
        enc := json.NewEncoder(w)
        for response := range request.Responses {
            if err := enc.Encode(response); err != nil {
                slog.Error("Failed to encode response", "error", err)
                return
            }

            if f, ok := w.(http.Flusher); ok {
                f.Flush()
            }
        }
    })

    mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
        var b bytes.Buffer
        if _, err := io.Copy(&b, r.Body); err != nil {
            slog.Error("Failed to read request body", "error", err)
            http.Error(w, "Bad Request", http.StatusBadRequest)
            return
        }

        tokens := runner.Tokenizer.Encode(b.String(), true)

        if err := json.NewEncoder(w).Encode(tokens); err != nil {
            slog.Error("Failed to encode response", "error", err)
            http.Error(w, "Internal Server Error", http.StatusInternalServerError)
            return
        }
    })

    for source, target := range map[string]string{
        "GET /health":      "/v1/status",
        "POST /load":       "/v1/models",
        "POST /completion": "/v1/completions",
    } {
        mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
    }

    return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Content-Type", "application/json")
        recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
        t := time.Now()
        mux.ServeHTTP(recorder, r)

        var level slog.Level
        switch {
        case recorder.code >= 500:
            level = slog.LevelError
        case recorder.code >= 400:
            level = slog.LevelWarn
        case recorder.code >= 300:
            return
        }

        slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
    }))
}

type statusRecorder struct {
    http.ResponseWriter
    code int
}

func (w *statusRecorder) WriteHeader(code int) {
    w.code = code
    w.ResponseWriter.WriteHeader(code)
}

func (w *statusRecorder) Status() string {
    return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}

func (w *statusRecorder) Flush() {
    if f, ok := w.ResponseWriter.(http.Flusher); ok {
        f.Flush()
    }
}
10 x/mlxrunner/server_stub.go Normal file
@@ -0,0 +1,10 @@
//go:build !mlx

package mlxrunner

import "errors"

// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
    return errors.New("MLX runner not available: build with mlx tag")
}
860 x/models/glm4_moe_lite/glm4_moe_lite.go Normal file
@@ -0,0 +1,860 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
||||
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
||||
package glm4_moe_lite
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// RopeScaling holds RoPE scaling configuration
|
||||
type RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
MscaleAllDim float32 `json:"mscale_all_dim"`
|
||||
}
|
||||
|
||||
// Config holds GLM4-MoE-Lite model configuration
|
||||
type Config struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
AttentionBias bool `json:"attention_bias"`
|
||||
|
||||
// MLA (Multi-head Latent Attention) parameters
|
||||
QLoraRank int32 `json:"q_lora_rank"`
|
||||
KVLoraRank int32 `json:"kv_lora_rank"`
|
||||
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
||||
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
||||
VHeadDim int32 `json:"v_head_dim"`
|
||||
|
||||
// MoE parameters
|
||||
NRoutedExperts int32 `json:"n_routed_experts"`
|
||||
NSharedExperts int32 `json:"n_shared_experts"`
|
||||
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
||||
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
||||
NormTopKProb bool `json:"norm_topk_prob"`
|
||||
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
||||
NGroup int32 `json:"n_group"`
|
||||
TopKGroup int32 `json:"topk_group"`
|
||||
|
||||
// RoPE scaling
|
||||
RopeScaling *RopeScaling `json:"rope_scaling"`
|
||||
|
||||
// Quantization parameters (set during load based on model quantization)
|
||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||
|
||||
// Computed fields
|
||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
|
||||
}
|
||||
|
||||
// MLAAttention implements Multi-head Latent Attention with absorption.
|
||||
type MLAAttention struct {
|
||||
QAProj nn.LinearLayer
|
||||
QALayerNorm *nn.RMSNorm
|
||||
QBProj nn.LinearLayer
|
||||
|
||||
KVAProjWithMQA nn.LinearLayer
|
||||
KVALayerNorm *nn.RMSNorm
|
||||
|
||||
EmbedQ *nn.MultiLinear
|
||||
UnembedOut *nn.MultiLinear
|
||||
|
||||
OProj nn.LinearLayer
|
||||
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
|
||||
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
||||
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||
|
||||
qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
||||
qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
||||
|
||||
compressedKV := a.KVAProjWithMQA.Forward(x)
|
||||
|
||||
kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
||||
kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
||||
|
||||
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
||||
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
||||
|
||||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
|
||||
qLatent := a.EmbedQ.Forward(qNope)
|
||||
|
||||
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
|
||||
|
||||
cachedL := L
|
||||
if c != nil {
|
||||
placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
|
||||
keys, _ = c.Update(keys, placeholderValues)
|
||||
cachedL = int32(keys.Dim(2))
|
||||
}
|
||||
|
||||
values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
|
||||
|
||||
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
||||
|
||||
out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
|
||||
|
||||
out = a.UnembedOut.Forward(out)
|
||||
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
||||
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
||||
type DenseMLP struct {
|
||||
GateProj nn.LinearLayer
|
||||
UpProj nn.LinearLayer
|
||||
DownProj nn.LinearLayer
|
||||
}
|
||||
|
||||
// Forward applies the SwiGLU MLP
|
||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
up := m.UpProj.Forward(x)
|
||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||
}
|
||||
|
||||
// MoEGate implements the expert gating mechanism
|
||||
type MoEGate struct {
|
||||
Gate nn.LinearLayer
|
||||
EScoreCorrectionBias *mlx.Array
|
||||
}
|
||||
|
||||
// Forward computes expert selection indices and scores
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
}
|
||||
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
dims := inds.Dims()
|
||||
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
||||
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
}
|
||||
|
||||
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
||||
|
||||
return inds, scores
|
||||
}

// SwitchMLP implements the MoE expert computation using stacked weights
type SwitchMLP struct {
	GateWeight *mlx.Array
	UpWeight   *mlx.Array
	DownWeight *mlx.Array

	GateWeightQ, GateScales, GateBiases *mlx.Array
	UpWeightQ, UpScales, UpBiases       *mlx.Array
	DownWeightQ, DownScales, DownBiases *mlx.Array

	GateBits int
	UpBits   int
	DownBits int

	GateGroupSize int
	UpGroupSize   int
	DownGroupSize int

	UseQuantized bool
}

// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	topK := cfg.NumExpertsPerTok

	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)

	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)

	idxFlat := mlx.Reshape(indices, B*L, topK)

	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array

	if s.UseQuantized {
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
			nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
			nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
			nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
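
// The doSort branch above groups (token, expert) pairs by expert index before
// the gathered matmuls (contiguous runs per expert are friendlier to
// GatherMM/GatherQMM), then undoes the permutation with the inverse argsort,
// i.e. Argsort(order). A hypothetical sketch of that round-trip, independent
// of MLX (assumes "sort"):
//
//	func sortByExpert(expert []int, rows [][]float32) ([][]float32, func([][]float32) [][]float32) {
//		order := make([]int, len(expert))
//		for i := range order {
//			order[i] = i
//		}
//		sort.Slice(order, func(a, b int) bool { return expert[order[a]] < expert[order[b]] })
//
//		inv := make([]int, len(order)) // inv = argsort(order): sorted position of each original row
//		for i, o := range order {
//			inv[o] = i
//		}
//
//		sorted := make([][]float32, len(rows))
//		for i, o := range order {
//			sorted[i] = rows[o]
//		}
//		restore := func(out [][]float32) [][]float32 {
//			orig := make([][]float32, len(out))
//			for o, i := range inv {
//				orig[o] = out[i] // out is in sorted order; put each row back
//			}
//			return orig
//		}
//		return sorted, restore
//	}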

// SharedExperts implements the shared expert MLP
type SharedExperts struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(s.GateProj.Forward(x))
	up := s.UpProj.Forward(x)
	return s.DownProj.Forward(mlx.Mul(gate, up))
}

// MoE implements the full Mixture of Experts layer
type MoE struct {
	Gate          *MoEGate
	SwitchMLP     *SwitchMLP
	SharedExperts *SharedExperts
}

// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])

	inds, scores := m.Gate.Forward(x, cfg)

	expertOut := m.SwitchMLP.Forward(x, inds, cfg)

	scoresExpanded := mlx.ExpandDims(scores, -1)
	y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)

	if m.SharedExperts != nil {
		y = mlx.Add(y, m.SharedExperts.Forward(x))
	}

	return mlx.Reshape(y, B, L, cfg.HiddenSize)
}

// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
	Attention              *MLAAttention
	MLP                    *DenseMLP
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
	return mlx.Add(h, r)
}

// MoEBlock represents a MoE transformer block
type MoEBlock struct {
	Attention              *MLAAttention
	MoE                    *MoE
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
	return mlx.Add(h, r)
}

// Block interface for both dense and MoE blocks
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}

// Model represents the complete GLM4-MoE-Lite model
type Model struct {
	EmbedTokens *nn.Embedding
	Layers      []Block
	Norm        *nn.RMSNorm
	LMHead      nn.LinearLayer

	tok *tokenizer.Tokenizer
	*Config
}

// computeScale computes the attention scale.
func computeScale(cfg *Config) float32 {
	keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	scale := float32(1.0 / math.Sqrt(float64(keyLength)))
	if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
		s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
		scale *= s * s
	}
	return scale
}
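
// Spelled out, the scale is mscale^2 / sqrt(QKNopeHeadDim + QKRopeHeadDim)
// with mscale = 0.1*mscale_all_dim*ln(factor) + 1, the YaRN-style
// softmax-scale correction applied when RoPE is interpolated. A worked
// example with purely illustrative numbers (not this model's config):
//
//	d_nope=128, d_rope=64, mscale_all_dim=1.0, factor=40
//	base   = 1/sqrt(192)         ~ 0.07217
//	mscale = 0.1*1.0*ln(40) + 1  ~ 1.36889
//	scale  = base * mscale^2     ~ 0.13524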

// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
func supportsGatherQMM(mode string, bits int) bool {
	return mode == "affine" && (bits == 4 || bits == 8)
}

// quantizationParams returns groupSize, bits, mode for a quantization type string.
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
	switch strings.ToUpper(quantization) {
	case "NVFP4":
		return 16, 4, "nvfp4"
	case "FP4", "Q4", "INT4":
		return 32, 4, "affine"
	case "MXFP8":
		return 32, 8, "mxfp8"
	case "FP8", "Q8", "INT8", "":
		return 64, 8, "affine"
	default:
		return 32, 8, "affine"
	}
}
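
// Together these two helpers decide whether the MoE path can stay quantized
// end to end. The pairings below follow directly from the switch above:
//
//	quantizationParams("NVFP4") // -> 16, 4, "nvfp4";  supportsGatherQMM("nvfp4", 4) == false -> experts dequantized at load
//	quantizationParams("Q4")    // -> 32, 4, "affine"; supportsGatherQMM("affine", 4) == true -> GatherQMM fast path
//	quantizationParams("")      // -> 64, 8, "affine"  (blobs without a quant_type default to 8-bit affine params)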

// readBlobMetadata reads the __metadata__ from a safetensors blob header.
func readBlobMetadata(path string) (map[string]string, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var headerSize uint64
	if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
		return nil, err
	}
	if headerSize > 1024*1024 {
		return nil, fmt.Errorf("header too large: %d", headerSize)
	}

	data := make([]byte, headerSize)
	if _, err := io.ReadFull(f, data); err != nil {
		return nil, err
	}

	var header map[string]json.RawMessage
	if err := json.Unmarshal(data, &header); err != nil {
		return nil, err
	}

	metaRaw, ok := header["__metadata__"]
	if !ok {
		return nil, nil
	}

	var meta map[string]string
	if err := json.Unmarshal(metaRaw, &meta); err != nil {
		return nil, err
	}
	return meta, nil
}
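
// The framing parsed here is standard safetensors: an 8-byte little-endian
// header length, then that many bytes of JSON in which "__metadata__" is an
// optional string-to-string map. A hypothetical blob head this function
// would accept (assumes "bytes"; the tensor payload is never read):
//
//	header := []byte(`{"__metadata__":{"quant_type":"NVFP4","group_size":"16"}}`)
//	var buf bytes.Buffer
//	binary.Write(&buf, binary.LittleEndian, uint64(len(header))) // length prefix
//	buf.Write(header)
//	// readBlobMetadata on these bytes yields meta["quant_type"] == "NVFP4"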

// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// loadExpertWeight loads an expert weight from the tensor map.
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
	w := tensors[path+".weight"]
	if w == nil {
		return nil
	}

	scales := tensors[path+".weight_scale"]
	if scales != nil {
		qbiases := tensors[path+".weight_qbias"]

		groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode

		if useQuantized && supportsGatherQMM(mode, bits) {
			return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
		}

		return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
	}

	return &ExpertWeight{Weight: w}
}

// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
	tensors map[string]*mlx.Array,
	prefix string,
	projName string,
	numExperts int32,
	useQuantized bool,
	cfg *Config,
) *StackedExpertWeights {
	var w, s, b []*mlx.Array
	var bits, groupSize int

	for e := int32(0); e < numExperts; e++ {
		path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
		ew := loadExpertWeight(tensors, path, useQuantized, cfg)
		if ew == nil {
			continue
		}
		w = append(w, ew.Weight)
		if ew.Scales != nil {
			s = append(s, ew.Scales)
		}
		if ew.Biases != nil {
			b = append(b, ew.Biases)
		}
		if e == 0 {
			bits = ew.Bits
			groupSize = ew.GroupSize
		}
	}

	result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
	if len(w) > 0 {
		result.Weight = mlx.Stack(w, 0)
		if len(s) > 0 {
			result.Scales = mlx.Stack(s, 0)
		}
		if len(b) > 0 {
			result.Biases = mlx.Stack(b, 0)
		}
	}
	return result
}

// sanitizeExpertWeights stacks individual expert weights into tensors.
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
	gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
	up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
	down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
	return gate, up, down
}

// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w := tensors[path+".weight"]
	if w == nil {
		return nil, nil
	}

	// Check if quantized and dequantize
	if scales := tensors[path+".weight_scale"]; scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
	}

	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)

	wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})

	embedQ := mlx.Transpose(wk, 0, 2, 1)
	unembedOut := wv

	return embedQ, unembedOut
}
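
// At the shape level, with H = NumAttentionHeads, Dk = QKNopeHeadDim,
// Dv = VHeadDim and R = KVLoraRank, the split above is:
//
//	kv_b_proj.weight : [H*(Dk+Dv), R]
//	after Reshape    : [H, Dk+Dv, R]
//	wk (rows 0..Dk)  : [H, Dk, R] -> EmbedQ     = wk^T : [H, R, Dk]
//	wv (rows Dk..)   : [H, Dv, R] -> UnembedOut = wv   : [H, Dv, R]
//
// EmbedQ projects queries into the R-dim latent so attention runs directly
// against the cached compressed KV latents, and UnembedOut maps the attended
// latents back to Dv per head: the "absorbed" MLA formulation used in
// MLAAttention.Forward above.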

// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
	w := tensors[path+".weight"]
	if w == nil {
		return nil
	}

	scales := tensors[path+".weight_scale"]
	if scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		bias := tensors[path+".bias"]
		return &nn.QuantizedLinear{
			Weight:    w,
			Scales:    scales,
			QBiases:   qbiases,
			Bias:      bias,
			GroupSize: cfg.QuantGroupSize,
			Bits:      cfg.QuantBits,
			Mode:      cfg.QuantMode,
		}
	}

	bias := tensors[path+".bias"]
	return nn.NewLinear(w, bias)
}

// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
	configData, err := modelManifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Load all tensors from manifest blobs into a flat map
	allTensors := make(map[string]*mlx.Array)
	seen := make(map[string]bool) // dedupe by digest
	var quantType string
	var quantGroupSize int

	for _, layer := range modelManifest.GetTensorLayers("") {
		if seen[layer.Digest] {
			continue
		}
		seen[layer.Digest] = true
		blobPath := modelManifest.BlobPath(layer.Digest)

		// Read quantization metadata from first blob
		if quantType == "" {
			if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil {
				if qt := meta["quant_type"]; qt != "" {
					quantType = strings.ToUpper(qt)
				}
				if gs := meta["group_size"]; gs != "" {
					fmt.Sscanf(gs, "%d", &quantGroupSize)
				}
			}
		}

		for name, arr := range mlx.Load(blobPath) {
			// Map safetensors key naming to our naming convention.
			// Combined blobs use ".scale" and ".bias" suffixes.
			if strings.HasSuffix(name, ".scale") {
				baseName := strings.TrimSuffix(name, ".scale")
				allTensors[baseName+"_scale"] = arr
			} else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
				// Distinguish a quantization bias from a regular layer bias
				// by checking for a corresponding scale tensor
				baseName := strings.TrimSuffix(name, ".bias")
				if _, hasScale := allTensors[baseName+"_scale"]; hasScale {
					allTensors[baseName+"_qbias"] = arr
				} else {
					allTensors[name] = arr
				}
			} else {
				allTensors[name] = arr
			}
		}
	}

	// Set up quantization parameters
	useQuantized := false
	if quantType != "" {
		_, cfg.QuantBits, cfg.QuantMode = quantizationParams(quantType)
		if quantGroupSize > 0 {
			cfg.QuantGroupSize = quantGroupSize
		} else {
			cfg.QuantGroupSize, _, _ = quantizationParams(quantType)
		}
		useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	}

	// Load tokenizer
	tokData, err := modelManifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData,
	}

	if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	// Load embedding
	if w := allTensors["model.embed_tokens.weight"]; w != nil {
		m.EmbedTokens = nn.NewEmbedding(w)
	}

	// Load final norm
	if w := allTensors["model.norm.weight"]; w != nil {
		m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
	}

	// Load LM head
	m.LMHead = makeLinear(allTensors, "lm_head", &cfg)

	// Load layers
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg)
		if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
			attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg)
		attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg)
		if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
			attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg)

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		inputLN := allTensors[prefix+".input_layernorm.weight"]
		postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"]

		if i < cfg.FirstKDenseReplace {
			// Dense block
			block := &DenseBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}

			block.MLP = &DenseMLP{
				GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg),
				UpProj:   makeLinear(allTensors, prefix+".mlp.up_proj", &cfg),
				DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg),
			}

			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}

			// Stack expert weights
			gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg)

			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			moeGate := &MoEGate{}
			moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg)
			if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
				moeGate.EScoreCorrectionBias = bias
			}

			block.MoE = &MoE{
				Gate:      moeGate,
				SwitchMLP: switchMLP,
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{
					GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg),
					UpProj:   makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg),
					DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg),
				}
			}

			m.Layers[i] = block
		}
	}

	mlx.Eval(mlx.Collect(m)...)

	return m, nil
}

// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Dims()
	B, L := int32(dims[0]), int32(dims[1])

	h := m.EmbedTokens.Forward(tokens)

	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil {
			c = caches[i]
		}
		h = layer.Forward(h, c, B, L, m.Config)
	}

	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}

// Unembed applies the LM head to get logits.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}

// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }

// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }

// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }

// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i := range caches {
		caches[i] = cache.NewKVCache()
	}
	return caches
}

// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
func (m *Model) FormatPrompt(prompt string) string {
	return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}

// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
	if think {
		return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
	}
	return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}

// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
	return &Renderer{}
}

// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
	return &Parser{}
}
479
x/models/glm4_moe_lite/parser.go
Normal file
@@ -0,0 +1,479 @@
//go:build mlx

package glm4_moe_lite

import (
	"context"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type parserState int

const (
	parserState_LookingForThinkingOpen parserState = iota
	parserState_ThinkingStartedEatingWhitespace
	parserState_CollectingThinking
	parserState_ThinkingDoneEatingWhitespace
	parserState_CollectingContent
	parserState_ToolStartedEatingWhitespace
	parserState_CollectingToolContent
)

const (
	thinkingOpenTag  = "<think>"
	thinkingCloseTag = "</think>"
	toolOpenTag      = "<tool_call>"
	toolCloseTag     = "</tool_call>"
)

// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
	state  parserState
	buffer strings.Builder
	tools  []api.Tool
}

// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
	return true
}

// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
	return true
}

// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
	p.tools = tools
	// When thinking is enabled (nil or true), the prompt ends with <think>,
	// so model output starts directly with thinking content (no opening tag).
	if thinkValue == nil || thinkValue.Bool() {
		p.state = parserState_CollectingThinking
	}
	return tools
}
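
// Because eat() withholds ambiguous suffixes (partial tags, trailing
// whitespace), Add can safely be fed arbitrarily small chunks. A minimal
// usage sketch; the chunk boundaries are deliberately hostile and purely
// illustrative:
//
//	p := &Parser{}
//	p.Init(nil, nil, &api.ThinkValue{Value: true}) // prompt ended with <think>
//	for _, chunk := range []string{"Let me th", "ink...</th", "ink>Answer."} {
//		content, thinking, _, err := p.Add(chunk, false)
//		if err != nil {
//			break
//		}
//		fmt.Print(thinking) // streams "Let me think..." without tag fragments
//		fmt.Print(content)  // streams "Answer." once </think> is confirmed
//	}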

type parserEvent interface {
	isParserEvent()
}

type eventContent struct {
	content string
}

func (eventContent) isParserEvent() {}

type eventRawToolCall struct {
	raw string
}

func (eventRawToolCall) isParserEvent() {}

type eventThinkingContent struct {
	content string
}

func (eventThinkingContent) isParserEvent() {}

// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
	p.buffer.WriteString(s)
	events := p.parseEvents()

	var toolCalls []api.ToolCall
	var contentSb strings.Builder
	var thinkingSb strings.Builder

	for _, event := range events {
		switch event := event.(type) {
		case eventRawToolCall:
			toolCall, err := parseToolCall(event, p.tools)
			if err != nil {
				slog.Warn("glm-4 tool call parsing failed", "error", err)
				return "", "", nil, err
			}
			toolCalls = append(toolCalls, toolCall)
		case eventThinkingContent:
			thinkingSb.WriteString(event.content)
		case eventContent:
			contentSb.WriteString(event.content)
		}
	}

	return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *Parser) parseEvents() []parserEvent {
	var all []parserEvent

	keepLooping := true
	for keepLooping {
		var events []parserEvent
		events, keepLooping = p.eat()
		if len(events) > 0 {
			all = append(all, events...)
		}
	}

	if len(all) > 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
	}

	return all
}

// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
	trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
	p.buffer.Reset()
	if trimmed == "" {
		return nil, false // Still only whitespace, keep waiting for more input
	}
	p.state = nextState
	p.buffer.WriteString(trimmed)
	return nil, true // Successfully transitioned
}

// splitAtTag splits the buffer at the given tag, returning the content before the
// tag (trimmed of trailing whitespace) and the content after it (optionally trimmed
// of leading whitespace); the buffer is reset to hold only the content after the tag.
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
	split := strings.SplitN(p.buffer.String(), tag, 2)
	before := split[0]
	before = strings.TrimRightFunc(before, unicode.IsSpace)
	after := split[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}
	p.buffer.Reset()
	p.buffer.WriteString(after)
	return before, after
}

func (p *Parser) eat() ([]parserEvent, bool) {
	var events []parserEvent

	switch p.state {
	case parserState_LookingForThinkingOpen:
		trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
		if strings.HasPrefix(trimmed, thinkingOpenTag) {
			// Found <think> opening tag
			after := strings.TrimPrefix(trimmed, thinkingOpenTag)
			after = strings.TrimLeftFunc(after, unicode.IsSpace)
			p.buffer.Reset()
			p.buffer.WriteString(after)
			if after == "" {
				p.state = parserState_ThinkingStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingThinking
			}
			return events, true
		} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
			// Partial opening tag seen, keep accumulating
			return events, false
		} else if trimmed == "" {
			// Only whitespace, keep accumulating
			return events, false
		} else {
			// No thinking tag found, skip to content collection
			p.state = parserState_CollectingContent
			// Don't trim - we want to keep the original content
			return events, true
		}

	case parserState_ThinkingStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)

	case parserState_CollectingThinking:
		acc := p.buffer.String()
		if strings.Contains(acc, thinkingCloseTag) {
			thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
			if len(thinking) > 0 {
				events = append(events, eventThinkingContent{content: thinking})
			}
			if remaining == "" {
				p.state = parserState_ThinkingDoneEatingWhitespace
			} else {
				p.state = parserState_CollectingContent
			}
			return events, true
		} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
			// Partial closing tag - withhold it along with any trailing whitespace before it
			beforePartialTag := acc[:len(acc)-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		} else {
			// Pure thinking content - withhold trailing whitespace (might precede closing tag)
			whitespaceLen := trailingWhitespaceLen(acc)
			ambiguousStart := len(acc) - whitespaceLen

			unambiguous := acc[:ambiguousStart]
			ambiguous := acc[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventThinkingContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ThinkingDoneEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)

	case parserState_CollectingContent:
		if strings.Contains(p.buffer.String(), toolOpenTag) {
			before, after := p.splitAtTag(toolOpenTag, true)
			if len(before) > 0 {
				events = append(events, eventContent{content: before})
			}
			if after == "" {
				p.state = parserState_ToolStartedEatingWhitespace
			} else {
				p.state = parserState_CollectingToolContent
			}
			return events, true
		} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
			beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
			trailingWsLen := trailingWhitespaceLen(beforePartialTag)
			ambiguousStart := len(beforePartialTag) - trailingWsLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		} else {
			whitespaceLen := trailingWhitespaceLen(p.buffer.String())
			ambiguousStart := len(p.buffer.String()) - whitespaceLen

			unambiguous := p.buffer.String()[:ambiguousStart]
			ambiguous := p.buffer.String()[ambiguousStart:]
			p.buffer.Reset()
			p.buffer.WriteString(ambiguous)
			if len(unambiguous) > 0 {
				events = append(events, eventContent{content: unambiguous})
			}
			return events, false
		}

	case parserState_ToolStartedEatingWhitespace:
		return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)

	case parserState_CollectingToolContent:
		acc := p.buffer.String()
		if strings.Contains(acc, toolCloseTag) {
			toolContent, _ := p.splitAtTag(toolCloseTag, true)
			if len(toolContent) == 0 {
				slog.Warn("glm4 tool call closing tag found but no content before it")
			}
			events = append(events, eventRawToolCall{raw: toolContent})
			p.state = parserState_CollectingContent
			return events, true
		} else {
			// Keep accumulating - tool calls are not streamed;
			// we just wait for the closing tag
			return events, false
		}

	default:
		panic("unreachable")
	}
}

// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
	for i := 1; i <= len(tag) && i <= len(s); i++ {
		if strings.HasSuffix(s, tag[:i]) {
			return i
		}
	}
	return 0
}

// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return len(s) - len(trimmed)
}

// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
	XMLName xml.Name `xml:"tool_call"`
	Content string   `xml:",chardata"` // Function name (text nodes between tags)
	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
	Values  []string `xml:"arg_value"` // All arg_value elements in document order
}

// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
	var result strings.Builder
	inTag := false

	for i := range len(s) {
		ch := s[i]

		if ch == '<' {
			// Check if this is a known tag
			if strings.HasPrefix(s[i:], "<arg_key>") ||
				strings.HasPrefix(s[i:], "</arg_key>") ||
				strings.HasPrefix(s[i:], "<arg_value>") ||
				strings.HasPrefix(s[i:], "</arg_value>") {
				inTag = true
			}
		}

		if inTag {
			result.WriteByte(ch)
			if ch == '>' {
				inTag = false
			}
		} else {
			// Escape special characters in text content
			switch ch {
			case '&':
				result.WriteString("&amp;")
			case '<':
				result.WriteString("&lt;")
			case '>':
				result.WriteString("&gt;")
			default:
				result.WriteByte(ch)
			}
		}
	}

	return result.String()
}

func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
	// Escape any unescaped entities in text content
	escaped := escapeContent(raw.raw)

	// Wrap the content in a root element to make it valid XML
	xmlString := "<tool_call>" + escaped + "</tool_call>"

	// Parse XML into struct
	var parsed ToolCallXML
	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
	}

	// Extract and trim function name
	functionName := strings.TrimSpace(parsed.Content)
	if functionName == "" {
		return api.ToolCall{}, fmt.Errorf("empty function name")
	}

	// Verify keys and values are paired correctly
	if len(parsed.Keys) != len(parsed.Values) {
		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
	}

	// Find the matching tool to get parameter types
	var matchedTool *api.Tool
	for i := range tools {
		if tools[i].Function.Name == functionName {
			matchedTool = &tools[i]
			break
		}
	}

	// Build arguments map by pairing keys and values
	toolCall := api.ToolCall{
		Function: api.ToolCallFunction{
			Name:      functionName,
			Arguments: api.NewToolCallFunctionArguments(),
		},
	}

	for i := range parsed.Keys {
		key := strings.TrimSpace(parsed.Keys[i])
		value := parsed.Values[i] // Don't trim here - parseValue handles it

		// Look up parameter type
		var paramType api.PropertyType
		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
				// Handle anyOf by collecting all types from the union
				if len(prop.AnyOf) > 0 {
					for _, anyOfProp := range prop.AnyOf {
						paramType = append(paramType, anyOfProp.Type...)
					}
				} else {
					paramType = prop.Type
				}
			}
		}

		// Parse value with type coercion
		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
	}

	return toolCall, nil
}

// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
	value = strings.TrimSpace(value)

	// If no type specified, return as string
	if len(paramType) == 0 {
		return value
	}

	// Try to parse based on specified types
	for _, t := range paramType {
		switch t {
		case "boolean":
			if value == "true" {
				return true
			}
			if value == "false" {
				return false
			}
		case "integer":
			var i int64
			if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
				return i
			}
		case "number":
			var f float64
			if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
				return f
			}
		case "array", "object":
			// Try to parse as JSON
			var result any
			if err := json.Unmarshal([]byte(value), &result); err == nil {
				return result
			}
		}
	}

	// Default to string
	return value
}
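
// A quick illustration of the coercion order; each declared type is tried in
// turn and the raw string is the fallback:
//
//	parseValue("42", api.PropertyType{"integer"})           // int64(42)
//	parseValue("3.14", api.PropertyType{"number"})          // float64(3.14)
//	parseValue("true", api.PropertyType{"boolean"})         // true
//	parseValue(`["a","b"]`, api.PropertyType{"array"})      // []any{"a", "b"}
//	parseValue("not-a-number", api.PropertyType{"integer"}) // "not-a-number" (fallback)
//	parseValue("hello", nil)                                // "hello" (no type info)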
192
x/models/glm4_moe_lite/parser_test.go
Normal file
@@ -0,0 +1,192 @@
//go:build mlx

package glm4_moe_lite

import (
	"testing"

	"github.com/ollama/ollama/api"
)

func TestParserThinking(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		thinkEnabled  bool
		wantContent   string
		wantThinking  string
		wantToolCalls int
	}{
		{
			name:         "thinking enabled - simple thinking then content",
			input:        "Let me think about this...</think>Here is my answer.",
			thinkEnabled: true,
			wantThinking: "Let me think about this...",
			wantContent:  "Here is my answer.",
		},
		{
			name:         "thinking enabled - only thinking",
			input:        "I need to consider multiple factors...",
			thinkEnabled: true,
			wantThinking: "I need to consider multiple factors...",
			wantContent:  "",
		},
		{
			name:         "thinking disabled - direct content",
			input:        "Here is my direct answer.",
			thinkEnabled: false,
			wantThinking: "",
			wantContent:  "Here is my direct answer.",
		},
		{
			name:          "thinking with tool call",
			input:         "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
			thinkEnabled:  true,
			wantThinking:  "Let me search for that...",
			wantContent:   "I'll use a tool.",
			wantToolCalls: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := &Parser{}

			var thinkValue *api.ThinkValue
			if tt.thinkEnabled {
				thinkValue = &api.ThinkValue{Value: true}
			} else {
				thinkValue = &api.ThinkValue{Value: false}
			}

			// Define tools for tool call tests
			props := api.NewToolPropertiesMap()
			props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
			tools := []api.Tool{
				{
					Function: api.ToolFunction{
						Name: "search",
						Parameters: api.ToolFunctionParameters{
							Properties: props,
						},
					},
				},
			}

			p.Init(tools, nil, thinkValue)

			content, thinking, calls, err := p.Add(tt.input, true)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if thinking != tt.wantThinking {
				t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
			}
			if content != tt.wantContent {
				t.Errorf("content = %q, want %q", content, tt.wantContent)
			}
			if len(calls) != tt.wantToolCalls {
				t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
			}
		})
	}
}

func TestParserToolCall(t *testing.T) {
	p := &Parser{}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name: "get_weather",
				Parameters: api.ToolFunctionParameters{
					Properties: props,
				},
			},
		},
	}

	// Initialize with thinking disabled
	tv := &api.ThinkValue{Value: false}
	p.Init(tools, nil, tv)

	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"

	_, _, calls, err := p.Add(input, true)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}

	call := calls[0]
	if call.Function.Name != "get_weather" {
		t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
	}

	location, ok := call.Function.Arguments.Get("location")
	if !ok || location != "San Francisco" {
		t.Errorf("location = %v, want %q", location, "San Francisco")
	}

	unit, ok := call.Function.Arguments.Get("unit")
	if !ok || unit != "celsius" {
		t.Errorf("unit = %v, want %q", unit, "celsius")
	}
}

func TestOverlap(t *testing.T) {
	tests := []struct {
		s    string
		tag  string
		want int
	}{
		{"hello<", "</think>", 1},
		{"hello</", "</think>", 2},
		{"hello</t", "</think>", 3},
		{"hello</th", "</think>", 4},
		{"hello</thi", "</think>", 5},
		{"hello</thin", "</think>", 6},
		{"hello</think", "</think>", 7},
		{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
		{"hello", "</think>", 0},
		{"", "</think>", 0},
	}

	for _, tt := range tests {
		t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
			got := overlap(tt.s, tt.tag)
			if got != tt.want {
				t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
			}
		})
	}
}

func TestTrailingWhitespaceLen(t *testing.T) {
	tests := []struct {
		s    string
		want int
	}{
		{"hello   ", 3},
		{"hello\n\t ", 3},
		{"hello", 0},
		{"", 0},
		{"   ", 3},
	}

	for _, tt := range tests {
		t.Run(tt.s, func(t *testing.T) {
			got := trailingWhitespaceLen(tt.s)
			if got != tt.want {
				t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
			}
		})
	}
}
175
x/models/glm4_moe_lite/render.go
Normal file
@@ -0,0 +1,175 @@
//go:build mlx

package glm4_moe_lite

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/ollama/ollama/api"
)

// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
//    The model thinks between tool calls and after receiving tool results.
//    This enables complex step-by-step reasoning: interpreting each tool output
//    before deciding what to do next. Thinking blocks are preserved and returned
//    with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
//    The model retains reasoning content from previous assistant turns in context.
//    This preserves reasoning continuity across multi-turn conversations. The
//    upstream API has a "clear_thinking" parameter to control this:
//    - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
//    - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
//    Controls whether the model should reason on each turn. The upstream API
//    uses the "enable_thinking" parameter:
//    - enable_thinking=true: outputs <think> to start reasoning
//    - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
//   - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
//   - Thinking is PRESERVED by default (reasoning content from previous turns is always
//     included in <think>...</think> blocks, equivalent to clear_thinking=false)
//   - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}

// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder

	sb.WriteString("[gMASK]<sop>")

	if len(tools) > 0 {
		sb.WriteString("<|system|>\n")
		sb.WriteString("# Tools\n\n")
		sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
		sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
		sb.WriteString("<tools>\n")
		for _, tool := range tools {
			d, _ := json.Marshal(tool)
			sb.WriteString(formatToolJSON(d))
			sb.WriteString("\n")
		}
		sb.WriteString("</tools>\n\n")
		sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
		sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
	}

	think := true
	if thinkValue != nil && !thinkValue.Bool() {
		think = false
	}

	for i, message := range messages {
		switch message.Role {
		case "user":
			sb.WriteString("<|user|>")
			sb.WriteString(message.Content)
		case "assistant":
			sb.WriteString("<|assistant|>")
			if message.Thinking != "" {
				sb.WriteString("<think>" + message.Thinking + "</think>")
			} else {
				sb.WriteString("</think>")
			}
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			if len(message.ToolCalls) > 0 {
				for _, toolCall := range message.ToolCalls {
					sb.WriteString("<tool_call>" + toolCall.Function.Name)
					sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
					sb.WriteString("</tool_call>")
				}
			}
		case "tool":
			if i == 0 || messages[i-1].Role != "tool" {
				sb.WriteString("<|observation|>")
			}
			sb.WriteString("<tool_response>")
			sb.WriteString(message.Content)
			sb.WriteString("</tool_response>")
		case "system":
			sb.WriteString("<|system|>")
			sb.WriteString(message.Content)
		}
	}

	sb.WriteString("<|assistant|>")
	if think {
		sb.WriteString("<think>")
	} else {
		sb.WriteString("</think>")
	}

	return sb.String(), nil
}
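
// For concreteness, a hypothetical tool-call exchange (no tools parameter
// passed, thinking left at the default) renders to the single unbroken string
// below; it is wrapped here for readability, and every fragment follows one
// of the WriteString calls above:
//
//	[gMASK]<sop><|user|>What's the weather in SF?
//	<|assistant|></think><tool_call>get_weather<arg_key>location</arg_key><arg_value>SF</arg_value></tool_call>
//	<|observation|><tool_response>Sunny, 72F</tool_response>
//	<|assistant|><think>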

// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
	var sb strings.Builder
	for key, value := range args.All() {
		sb.WriteString("<arg_key>" + key + "</arg_key>")
		var valueStr string
		if str, ok := value.(string); ok {
			valueStr = str
		} else {
			jsonBytes, err := json.Marshal(value)
			if err != nil {
				valueStr = fmt.Sprintf("%v", value)
			} else {
				valueStr = string(jsonBytes)
			}
		}

		sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
	}

	return sb.String()
}

// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
	var sb strings.Builder
	sb.Grow(len(raw) + len(raw)/10)

	inString := false
	escaped := false
	for i := range raw {
		ch := raw[i]
		sb.WriteByte(ch)

		if inString {
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == '"' {
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			continue
		}

		if ch == ':' || ch == ',' {
			sb.WriteByte(' ')
		}
	}

	return sb.String()
}
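
// Example: formatToolJSON([]byte(`{"name":"test","value":123}`)) returns
// `{"name": "test", "value": 123}`; string contents are left untouched
// because of the inString tracking above.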
205
x/models/glm4_moe_lite/render_test.go
Normal file
@@ -0,0 +1,205 @@
//go:build mlx

package glm4_moe_lite

import (
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
)

func TestRendererSimple(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	// Thinking enabled (default)
	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}

func TestRendererThinkingDisabled(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "Hello"},
	}

	tv := &api.ThinkValue{Value: false}

	result, err := r.Render(messages, nil, tv)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
	if result != expected {
		t.Errorf("result = %q, want %q", result, expected)
	}
}

func TestRendererMultiTurn(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What is 2+2?"},
		{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
		{Role: "user", Content: "And 3+3?"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check key parts
	if !strings.Contains(result, "[gMASK]<sop>") {
		t.Error("missing [gMASK]<sop> prefix")
	}
	if !strings.Contains(result, "<|user|>What is 2+2?") {
		t.Error("missing first user message")
	}
	if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
		t.Error("missing assistant message with thinking")
	}
	if !strings.Contains(result, "<|user|>And 3+3?") {
		t.Error("missing second user message")
	}
	if !strings.HasSuffix(result, "<|assistant|><think>") {
		t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
	}
}

func TestRendererWithSystem(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
		t.Error("missing system message")
	}
}

func TestRendererWithTools(t *testing.T) {
	r := &Renderer{}

	messages := []api.Message{
		{Role: "user", Content: "What's the weather?"},
	}

	props := api.NewToolPropertiesMap()
	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
	tools := []api.Tool{
		{
			Function: api.ToolFunction{
				Name:        "get_weather",
				Description: "Get the weather for a location",
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
					Required:   []string{"location"},
				},
			},
		},
	}

	result, err := r.Render(messages, tools, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check for tool system prompt
	if !strings.Contains(result, "<|system|>") {
		t.Error("missing system tag for tools")
	}
	if !strings.Contains(result, "# Tools") {
		t.Error("missing tools header")
	}
	if !strings.Contains(result, "<tools>") {
		t.Error("missing tools tag")
	}
	if !strings.Contains(result, "get_weather") {
		t.Error("missing tool name")
	}
	if !strings.Contains(result, "</tools>") {
		t.Error("missing closing tools tag")
	}
}

func TestRendererWithToolCalls(t *testing.T) {
	r := &Renderer{}

	args := api.NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")

	messages := []api.Message{
		{Role: "user", Content: "What's the weather in SF?"},
		{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: args,
					},
				},
			},
		},
		{Role: "tool", Content: "Sunny, 72F"},
	}

	result, err := r.Render(messages, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !strings.Contains(result, "<tool_call>get_weather") {
		t.Error("missing tool call")
	}
	if !strings.Contains(result, "<arg_key>location</arg_key>") {
		t.Error("missing arg_key")
	}
	if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
		t.Error("missing arg_value")
	}
	if !strings.Contains(result, "</tool_call>") {
		t.Error("missing tool call closing tag")
	}
	if !strings.Contains(result, "<|observation|>") {
		t.Error("missing observation tag")
	}
	if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
		t.Error("missing tool response")
	}
}

func TestFormatToolJSON(t *testing.T) {
	input := []byte(`{"name":"test","value":123}`)
	result := formatToolJSON(input)

	// Should add spaces after : and ,
	if !strings.Contains(result, ": ") {
		t.Error("should add space after colon")
	}
	if !strings.Contains(result, ", ") {
		t.Error("should add space after comma")
	}
}
186
x/models/nn/nn.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package nn
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// Layer is the interface for neural network layers with a Forward method.
|
||||
type Layer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// LinearLayer is an interface for linear layers (both regular and quantized).
|
||||
type LinearLayer interface {
|
||||
Forward(x *mlx.Array) *mlx.Array
|
||||
OutputDim() int32
|
||||
}
|
||||
|
||||
// Linear applies an affine transformation: y = x @ W.T + b
|
||||
type Linear struct {
|
||||
Weight *mlx.Array
|
||||
Bias *mlx.Array
|
||||
}
|
||||
|
||||
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
||||
w := l.Weight.Transpose(1, 0)
|
||||
if l.Bias != nil && l.Bias.Valid() {
|
||||
return l.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
}
|
||||
return x.Matmul(w)
|
||||
}
|
||||
|
||||
func (l *Linear) OutputDim() int32 {
|
||||
return int32(l.Weight.Dim(0))
|
||||
}
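
// Usage sketch (shapes assumed from the transpose above): with Weight [out, in]
// and x [..., in], Forward returns [..., out]; Addmm is taken to compute
// bias + x @ w in one fused call.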

// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (nil for nvfp4)
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
Mode string
}

func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
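
// Assumption: mlx arrays are lazily evaluated, so the Eval calls above force
// the quantized weight, scales, and qbiases to materialize once at construction.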

func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
if ql.Bias != nil && ql.Bias.Valid() {
out = out.Add(ql.Bias)
}
return out
}

func (ql *QuantizedLinear) OutputDim() int32 {
return int32(ql.Weight.Dim(0))
}

// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array
Eps float32
}

func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}

func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNormFn(x, rn.Weight, eps)
}

// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array
}

func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}

func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
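// Gather one row of Weight per index along axis 0: a plain embedding-table lookup.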
return e.Weight.TakeAxis(indices, 0)
}

// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array
Bias *mlx.Array
Eps float32
}

func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
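// Classic layer norm: normalize x to zero mean and unit variance over the
// last axis, then scale by Weight and shift by Bias when present.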
mean := mlx.Mean(x, -1, true)
centered := x.Subtract(mean)
variance := mlx.Mean(centered.Multiply(centered), -1, true)
normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps)))
out := normalized.Multiply(ln.Weight)
if ln.Bias != nil && ln.Bias.Valid() {
out = out.Add(ln.Bias)
}
return out
}

// MultiLinearLayer is an interface for per-head linear layers.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}

// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
type MultiLinear struct {
Weight *mlx.Array
}

func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}

func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
wT := ml.Weight.Transpose(0, 2, 1)
return x.Matmul(wT)
}

// RepeatKV repeats K/V tensors for grouped query attention.
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
if repeatFactor == 1 {
return x
}
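// Shape walkthrough: [B, H_kv, S, D] -> ExpandDims(2) -> [B, H_kv, 1, S, D]
// -> Tile x repeatFactor -> [B, H_kv, r, S, D] -> Reshape -> [B, H_kv*r, S, D].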
shape := x.Dims()
x = x.ExpandDims(2)
reps := []int32{1, 1, repeatFactor, 1, 1}
x = mlx.Tile(x, reps)
return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3]))
}

// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
shape := scores.Dims()
seqLen := int32(shape[2])
mask := mlx.Tri(seqLen, seqLen, 0)
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}

// ApplyCausalMaskWithOffset applies causal mask for cached attention.
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
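// With offset tokens already cached, query i of this step may attend to keys
// 0..offset+i, so the triangular mask is shifted by offset columns; Where
// keeps scores inside the mask and writes -1e9 elsewhere.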
shape := scores.Dims()
queryLen := int32(shape[2])
keyLen := int32(shape[3])
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}
394
x/server/show.go
394
x/server/show.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"sort"
"strings"

"github.com/ollama/ollama/api"
@@ -105,9 +106,9 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
bytesPerParam = 1
}

// Subtract safetensors header overhead (88 bytes per tensor file)
// Each tensor is stored as a minimal safetensors file
totalBytes := totalTensorBytes - tensorCount*88
// Subtract safetensors header overhead per tensor blob.
// Headers include __metadata__ with the tensor name, so overhead is ~150 bytes on average.
totalBytes := totalTensorBytes - tensorCount*150

paramCount := totalBytes / bytesPerParam
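
// e.g. a single fp16 blob holding 1M params: 2_000_150 bytes - 150 header
// bytes = 2_000_000, divided by 2 bytes/param = 1_000_000 params (see the
// TestBuildModelInfo cases below).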

@@ -163,24 +164,103 @@ func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {

// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
// For quantized models, groups weight/scale/qbias into single entries with detected quantization type.
// For quantized tensors, reads quant_type from blob __metadata__.
// For packed blobs (multiple tensors per blob), enumerates all tensors in the blob.
func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor

// First pass: collect all tensor info and identify scale tensors
type tensorData struct {
info *safetensorsTensorInfo
digest string
}
tensorMap := make(map[string]*tensorData)
scaleMap := make(map[string]*tensorData) // base name -> scale tensor info

for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}

// Read the safetensors header from the blob
// Read all tensor entries from the safetensors header
blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
continue
}

f, err := os.Open(blobPath)
if err != nil {
continue
}

allInfos, err := parseSafetensorsAllHeaders(f)
f.Close()
if err != nil {
continue
}

// Determine if this is a packed blob (multiple main tensors)
isPacked := len(allInfos) > 1

for _, info := range allInfos {
tensorName := layer.Name
if isPacked {
// For packed blobs, use the tensor name from the header
tensorName = info.Name
}

if info.QuantType != "" {
quantType := strings.ToUpper(info.QuantType)

shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}

var packFactor int64
switch strings.ToLower(info.QuantType) {
case "int4", "nvfp4":
packFactor = 8
case "int8", "mxfp8":
packFactor = 4
}
if packFactor > 0 && len(shape) >= 2 {
shape[len(shape)-1] = uint64(info.Shape[len(info.Shape)-1] * packFactor)
}

tensors = append(tensors, api.Tensor{
Name: tensorName,
Type: quantType,
Shape: shape,
})
} else {
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
shape[i] = uint64(s)
}

tensors = append(tensors, api.Tensor{
Name: tensorName,
Type: info.Dtype,
Shape: shape,
})
}
}
}

sort.Slice(tensors, func(i, j int) bool {
return tensors[i].Name < tensors[j].Name
})

return tensors, nil
}

// GetSafetensorsDtype returns the quantization type for a safetensors model.
// Reads quant_type from the first tensor blob's __metadata__.
// Falls back to torch_dtype from config.json if no quant metadata.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}

// Check first tensor blob for quant_type metadata
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
continue
@@ -189,131 +269,11 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
if err != nil {
continue
}

td := &tensorData{info: info, digest: layer.Digest}

if strings.HasSuffix(layer.Name, "_scale") {
baseName := strings.TrimSuffix(layer.Name, "_scale")
scaleMap[baseName] = td
} else if strings.HasSuffix(layer.Name, "_qbias") {
// Skip qbias tensors - they're included with the quantized weight
continue
} else {
tensorMap[layer.Name] = td
if info.QuantType != "" {
return strings.ToUpper(info.QuantType), nil
}
}

// Second pass: build tensor list with quantization info
for _, layer := range mf.Layers {
if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}

// Skip scale and qbias tensors
if strings.HasSuffix(layer.Name, "_scale") || strings.HasSuffix(layer.Name, "_qbias") {
continue
}

td := tensorMap[layer.Name]
if td == nil {
continue
}

// Check if this tensor has a corresponding scale tensor (quantized)
scaleTd := scaleMap[layer.Name]
if scaleTd != nil && len(td.info.Shape) >= 2 && len(scaleTd.info.Shape) >= 2 {
// Quantized tensor - detect bits from shapes
weightCols := td.info.Shape[len(td.info.Shape)-1]
scaleCols := scaleTd.info.Shape[len(scaleTd.info.Shape)-1]

// Detect quantization: Q4 has pack_factor=8, Q8 has pack_factor=4
// Q4 uses group_size=32: weightCols * 8 / scaleCols = 32
// Q8 uses group_size=64: weightCols * 4 / scaleCols = 64
var bits int
var quantType string
if weightCols*8/scaleCols == 32 {
bits = 4
quantType = "Q4"
} else if weightCols*4/scaleCols == 64 {
bits = 8
quantType = "Q8"
} else {
// Unknown quantization, show raw
quantType = td.info.Dtype
}

// Calculate unpacked shape
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}
if bits > 0 {
packFactor := int64(32 / bits)
shape[len(shape)-1] = uint64(td.info.Shape[len(td.info.Shape)-1] * packFactor)
}

tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: quantType,
Shape: shape,
})
} else {
// Non-quantized tensor
shape := make([]uint64, len(td.info.Shape))
for i, s := range td.info.Shape {
shape[i] = uint64(s)
}

tensors = append(tensors, api.Tensor{
Name: layer.Name,
Type: td.info.Dtype,
Shape: shape,
})
}
}

return tensors, nil
}

// GetSafetensorsDtype returns the quantization type for a safetensors model.
// Reads from model_index.json first, falls back to detection from tensor names.
// Otherwise returns the torch_dtype from config.json.
func GetSafetensorsDtype(name model.Name) (string, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}

// First try to read quantization from model_index.json
var modelIndex struct {
Quantization string `json:"quantization"`
}
if err := mf.ReadConfigJSON("model_index.json", &modelIndex); err == nil && modelIndex.Quantization != "" {
return modelIndex.Quantization, nil
}

// Fallback: detect from tensor names
hasScales := false
hasQBias := false
for _, layer := range mf.Layers {
if layer.MediaType == manifest.MediaTypeImageTensor {
if strings.HasSuffix(layer.Name, "_scale") {
hasScales = true
}
if strings.HasSuffix(layer.Name, "_qbias") {
hasQBias = true
}
}
}

if hasScales {
if hasQBias {
// Affine mode (has scale + qbias) - could be Q4 or Q8
// Default to Q4 as it's more common
return "Q4", nil
}
// No qbias = NVFP4
return "NVFP4", nil
// Only check the first tensor blob
break
}

// Not quantized - return torch_dtype from config.json
@@ -329,8 +289,11 @@ func GetSafetensorsDtype(name model.Name) (string, error) {

// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
Name string // tensor name from the header key
Dtype string `json:"dtype"`
Shape []int64 `json:"shape"`
QuantType string // from __metadata__.quant_type (e.g., "int4", "int8", "nvfp4", "mxfp8")
GroupSize string // from __metadata__.group_size (e.g., "32", "64")
}

// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
@@ -347,6 +310,7 @@ func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {

// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
// Parses __metadata__ for quant_type and group_size if present.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
@@ -371,7 +335,31 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
return nil, fmt.Errorf("failed to parse header: %w", err)
}

// Find the first (and should be only) tensor entry
// Parse metadata if present
var quantType, groupSize string
if metaRaw, ok := header["__metadata__"]; ok {
var meta map[string]string
if json.Unmarshal(metaRaw, &meta) == nil {
quantType = meta["quant_type"]
groupSize = meta["group_size"]
}
}

// Find the main tensor entry (not __metadata__, .scale, or .bias)
for name, raw := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
info.QuantType = quantType
info.GroupSize = groupSize
return &info, nil
}

// Fall back to first non-metadata tensor entry
for name, raw := range header {
if name == "__metadata__" {
continue
@@ -380,8 +368,134 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
info.QuantType = quantType
info.GroupSize = groupSize
return &info, nil
}

return nil, fmt.Errorf("no tensor found in header")
}

// parseSafetensorsAllHeaders parses all tensor entries from a safetensors header.
// Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias).
// For packed blobs this returns multiple entries; for single-tensor blobs, one entry.
// Each tensor's quant type is inferred from its shape and the presence of .scale/.bias entries
// when no global __metadata__ quant_type is present.
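// Layout note: a safetensors blob begins with an 8-byte little-endian header
// length, then that many bytes of JSON header; raw tensor data follows.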
func parseSafetensorsAllHeaders(r io.Reader) ([]safetensorsTensorInfo, error) {
var headerSize uint64
if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
return nil, fmt.Errorf("failed to read header size: %w", err)
}

if headerSize > 100*1024*1024 { // 100MB limit for packed blob headers
return nil, fmt.Errorf("header size too large: %d", headerSize)
}

headerBytes := make([]byte, headerSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("failed to read header: %w", err)
}

var header map[string]json.RawMessage
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse header: %w", err)
}

// Parse global metadata if present
var globalQuantType, globalGroupSize string
if metaRaw, ok := header["__metadata__"]; ok {
var meta map[string]string
if json.Unmarshal(metaRaw, &meta) == nil {
globalQuantType = meta["quant_type"]
globalGroupSize = meta["group_size"]
}
}

// Build a set of all keys for checking .scale/.bias presence
headerKeys := make(map[string]bool, len(header))
for k := range header {
headerKeys[k] = true
}

// Collect all main tensor entries (sorted for deterministic output)
var mainNames []string
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
mainNames = append(mainNames, name)
}
sort.Strings(mainNames)

var results []safetensorsTensorInfo
for _, name := range mainNames {
var info safetensorsTensorInfo
if err := json.Unmarshal(header[name], &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info for %s: %w", name, err)
}
info.Name = name

if globalQuantType != "" {
// Use global metadata
info.QuantType = globalQuantType
info.GroupSize = globalGroupSize
} else if headerKeys[name+".scale"] {
// No global metadata, but has .scale - infer quant type from shape
info.QuantType = inferQuantType(header, name)
}

results = append(results, info)
}

if len(results) == 0 {
return nil, fmt.Errorf("no tensor found in header")
}

return results, nil
}

// inferQuantType infers the quantization type for a tensor from its shape and scale shape.
// Returns "int4", "int8", etc. or "" if not quantized.
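// Worked example (mirroring the tests): a 2560x2560 int4 weight packs to
// [2560, 320] u32 values with scales [2560, 80], so ratio = 320/80 = 4 -> "int4".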
func inferQuantType(header map[string]json.RawMessage, name string) string {
// Parse the main tensor shape
var mainInfo struct {
Shape []int64 `json:"shape"`
}
if json.Unmarshal(header[name], &mainInfo) != nil || len(mainInfo.Shape) < 2 {
return ""
}

// Parse scale shape to determine group size
scaleRaw, ok := header[name+".scale"]
if !ok {
return ""
}
var scaleInfo struct {
Shape []int64 `json:"shape"`
}
if json.Unmarshal(scaleRaw, &scaleInfo) != nil || len(scaleInfo.Shape) < 2 {
return ""
}

// Calculate group size: main_cols * pack_factor / scale_cols
// Main dtype is U32, so we need to figure out the pack factor
// For int4: pack=8, group=32. scale_cols = original_cols / 32 = main_cols * 8 / 32 = main_cols / 4
// For int8: pack=4, group=64. scale_cols = original_cols / 64 = main_cols * 4 / 64 = main_cols / 16
mainCols := mainInfo.Shape[len(mainInfo.Shape)-1]
scaleCols := scaleInfo.Shape[len(scaleInfo.Shape)-1]
if scaleCols == 0 {
return ""
}

ratio := mainCols / scaleCols // main_packed_cols / scale_cols
// int4: ratio = (orig/8) / (orig/32) = 32/8 = 4
// int8: ratio = (orig/4) / (orig/64) = 64/4 = 16
switch ratio {
case 4:
return "int4"
case 16:
return "int8"
default:
return ""
}
}

@@ -36,7 +36,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
totalTensorBytes: 8_600_000_150, // ~4.3B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -57,7 +57,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 32000,
TorchDtype: "float16",
},
totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
totalTensorBytes: 14_000_000_150, // ~7B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
@@ -84,7 +84,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
totalTensorBytes: 8_600_000_088,
totalTensorBytes: 8_600_000_150,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -101,7 +101,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
totalTensorBytes: 400_000_150, // 100M params * 4 bytes + 150 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
@@ -118,7 +118,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
totalTensorBytes: 2_001_500, // 1M params * 2 bytes + 10 tensors * 150 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
@@ -230,42 +230,42 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
{
name: "bfloat16",
dtype: "bfloat16",
totalBytes: 2_000_088, // 1M * 2 + 88
totalBytes: 2_000_150, // 1M * 2 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
totalBytes: 4_000_088, // 1M * 4 + 88
totalBytes: 4_000_150, // 1M * 4 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
totalBytes: 1_000_088, // 1M * 1 + 88
totalBytes: 1_000_150, // 1M * 1 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
totalBytes: 2_000_088,
totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
@@ -288,11 +288,13 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {

func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantDtype string
wantShape []int64
wantErr bool
name string
header map[string]any
wantDtype string
wantShape []int64
wantQuantType string
wantGroupSize string
wantErr bool
}{
{
name: "simple tensor",
@@ -307,7 +309,70 @@ func TestParseSafetensorsHeader(t *testing.T) {
wantShape: []int64{2560, 262144},
},
{
name: "with metadata",
name: "tensor keyed by name",
header: map[string]any{
"model.layers.0.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 2560},
"data_offsets": []int64{0, 13107200},
},
},
wantDtype: "BF16",
wantShape: []int64{2560, 2560},
},
{
name: "with int4 quant metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 320},
"data_offsets": []int64{0, 3276800},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3276800, 3686400},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3686400, 4096000},
},
},
wantDtype: "U32",
wantShape: []int64{2560, 320},
wantQuantType: "int4",
wantGroupSize: "32",
},
{
name: "int8 quant metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int8",
"group_size": "64",
},
"model.layers.0.mlp.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 640},
"data_offsets": []int64{0, 6553600},
},
"model.layers.0.mlp.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 40},
"data_offsets": []int64{6553600, 6963200},
},
},
wantDtype: "U32",
wantShape: []int64{2560, 640},
wantQuantType: "int8",
wantGroupSize: "64",
},
{
name: "with old-style format metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
@@ -371,6 +436,13 @@ func TestParseSafetensorsHeader(t *testing.T) {
}
}
}

if info.QuantType != tt.wantQuantType {
t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType)
}
if info.GroupSize != tt.wantGroupSize {
t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize)
}
})
}
}
@@ -460,7 +532,7 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
t.Fatalf("failed to create blobs dir: %v", err)
}

// Create test tensor blobs
// Create test tensor blobs with __metadata__
tensors := []struct {
name string
digest string
@@ -487,10 +559,9 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
},
}

// Create blob files
// Create blob files with tensor keyed by name
var layers []manifest.Layer
for _, tensor := range tensors {
// Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
@@ -561,6 +632,391 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
}

func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)

blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}

// Create a combined quantized blob with __metadata__
header := map[string]any{
"__metadata__": map[string]string{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 320}, // packed: 2560 / 8 = 320
"data_offsets": []int64{0, 3276800},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80}, // 2560 / 32 = 80
"data_offsets": []int64{3276800, 3686400},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 80},
"data_offsets": []int64{3686400, 4096000},
},
}
headerJSON, _ := json.Marshal(header)

var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)

digest := "sha256:aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb"
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}

mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest,
Size: int64(buf.Len() + 4096000),
Name: "model.layers.0.mlp.up_proj.weight",
},
},
}

result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}

if len(result) != 1 {
t.Fatalf("got %d tensors, want 1", len(result))
}

tensor := result[0]
if tensor.Name != "model.layers.0.mlp.up_proj.weight" {
t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name)
}
if tensor.Type != "INT4" {
t.Errorf("Type = %v, want INT4", tensor.Type)
}
// Shape should be unpacked: 320 * 8 = 2560
if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 {
t.Errorf("Shape = %v, want [2560, 2560]", tensor.Shape)
}
}

func TestParseSafetensorsAllHeaders(t *testing.T) {
tests := []struct {
name string
header map[string]any
wantCount int
wantNames []string
wantDtypes []string
wantQuants []string
wantErr bool
}{
{
name: "single tensor blob",
header: map[string]any{
"model.layers.0.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 2560},
"data_offsets": []int64{0, 13107200},
},
},
wantCount: 1,
wantNames: []string{"model.layers.0.weight"},
wantDtypes: []string{"BF16"},
wantQuants: []string{""},
},
{
name: "packed unquantized blob",
header: map[string]any{
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 10240},
"data_offsets": []int64{0, 52428800},
},
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 2560},
"data_offsets": []int64{52428800, 104857600},
},
"model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 2560},
"data_offsets": []int64{104857600, 157286400},
},
},
wantCount: 3,
wantNames: []string{
"model.layers.0.mlp.experts.0.down_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
"model.layers.0.mlp.experts.0.up_proj.weight",
},
wantDtypes: []string{"BF16", "BF16", "BF16"},
wantQuants: []string{"", "", ""},
},
{
name: "packed quantized blob with global metadata",
header: map[string]any{
"__metadata__": map[string]any{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{16384000, 29491200},
},
"model.layers.0.mlp.experts.0.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{29491200, 31129600},
},
"model.layers.0.mlp.experts.0.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{31129600, 32768000},
},
},
wantCount: 2,
wantNames: []string{
"model.layers.0.mlp.experts.0.gate_proj.weight",
"model.layers.0.mlp.experts.0.up_proj.weight",
},
wantDtypes: []string{"U32", "U32"},
wantQuants: []string{"int4", "int4"},
},
{
name: "packed mixed-precision blob (no global metadata)",
header: map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 2560},
"data_offsets": []int64{16384000, 42598400},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 160},
"data_offsets": []int64{42598400, 43417600},
},
},
wantCount: 2,
wantNames: []string{
"model.layers.0.mlp.experts.0.down_proj.weight",
"model.layers.0.mlp.experts.0.gate_proj.weight",
},
wantDtypes: []string{"U32", "U32"},
wantQuants: []string{"int8", "int4"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headerJSON, err := json.Marshal(tt.header)
if err != nil {
t.Fatalf("failed to marshal header: %v", err)
}

var buf bytes.Buffer
if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
t.Fatalf("failed to write header size: %v", err)
}
buf.Write(headerJSON)

results, err := parseSafetensorsAllHeaders(&buf)
if (err != nil) != tt.wantErr {
t.Errorf("parseSafetensorsAllHeaders() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}

if len(results) != tt.wantCount {
t.Fatalf("got %d tensors, want %d", len(results), tt.wantCount)
}

for i, info := range results {
if info.Name != tt.wantNames[i] {
t.Errorf("tensor[%d].Name = %v, want %v", i, info.Name, tt.wantNames[i])
}
if info.Dtype != tt.wantDtypes[i] {
t.Errorf("tensor[%d].Dtype = %v, want %v", i, info.Dtype, tt.wantDtypes[i])
}
if info.QuantType != tt.wantQuants[i] {
t.Errorf("tensor[%d].QuantType = %v, want %v", i, info.QuantType, tt.wantQuants[i])
}
}
})
}
}

func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)

blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}

// Create a packed blob with multiple expert tensors (mixed quantization)
header := map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10240, 320},
"data_offsets": []int64{0, 13107200},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{13107200, 14745600},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10240, 80},
"data_offsets": []int64{14745600, 16384000},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{2560, 2560},
"data_offsets": []int64{16384000, 42598400},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{2560, 160},
"data_offsets": []int64{42598400, 43417600},
},
}
headerJSON, _ := json.Marshal(header)

var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)

packedDigest := "sha256:aaaa000000000000000000000000000000000000000000000000000000000001"
blobPath, err := manifest.BlobsPath(packedDigest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write packed blob: %v", err)
}

// Also create a regular (single-tensor) blob
singleHeader := map[string]any{
"model.embed_tokens.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{262144, 2560},
"data_offsets": []int64{0, 1342177280},
},
}
singleHeaderJSON, _ := json.Marshal(singleHeader)
var singleBuf bytes.Buffer
binary.Write(&singleBuf, binary.LittleEndian, uint64(len(singleHeaderJSON)))
singleBuf.Write(singleHeaderJSON)

singleDigest := "sha256:bbbb000000000000000000000000000000000000000000000000000000000002"
singleBlobPath, err := manifest.BlobsPath(singleDigest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(singleBlobPath, singleBuf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write single blob: %v", err)
}

mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: singleDigest,
Size: int64(singleBuf.Len()),
Name: "model.embed_tokens.weight",
},
{
MediaType: manifest.MediaTypeImageTensor,
Digest: packedDigest,
Size: int64(buf.Len()),
Name: "model.layers.0.mlp.experts", // group prefix
},
},
}

result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}

// Should have 3 tensors: 1 single + 2 packed main tensors
if len(result) != 3 {
t.Fatalf("got %d tensors, want 3. Tensors: %v", len(result), result)
}

// First tensor should be the single blob
if result[0].Name != "model.embed_tokens.weight" {
t.Errorf("tensor[0].Name = %v, want model.embed_tokens.weight", result[0].Name)
}
if result[0].Type != "BF16" {
t.Errorf("tensor[0].Type = %v, want BF16", result[0].Type)
}

// Packed tensors should have their actual names (sorted)
packedNames := make(map[string]bool)
for _, r := range result[1:] {
packedNames[r.Name] = true
}
if !packedNames["model.layers.0.mlp.experts.0.down_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.down_proj.weight")
}
if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight")
}
}

func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir := t.TempDir()