Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Yang
b367596c58 s/tensor/array/g 2026-02-06 09:39:04 -08:00
Michael Yang
7c027625ef mlxrunner 2026-02-06 09:39:04 -08:00
Michael Yang
092ffbb2f6 draft: model manifest file interface
this change makes it easier to address blobs by their original names
without rebuilding the filesystem structure
2026-02-06 09:39:04 -08:00
45 changed files with 15135 additions and 18 deletions

5
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -29,6 +29,8 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -50,6 +52,7 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

31
go.sum
View File

@@ -152,6 +152,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -206,12 +208,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@@ -4,6 +4,7 @@ import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -17,6 +18,8 @@ func Execute(args []string) error {
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)

View File

@@ -5,9 +5,13 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"os"
"os/exec"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -22,6 +26,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -195,25 +200,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if s.loadSafetensors(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -563,9 +557,90 @@ iGPUScan:
return false
}
// loadMLX loads an experimental safetensors model using the unified MLX runner.
func subproc(args, environ []string) (*exec.Cmd, int, error) {
exe, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
}
for range 3 {
// get a random port in the ephemeral range
port := rand.Intn(65535-49152) + 49152
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
cmd.Env = slices.Concat(os.Environ(), environ)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
continue
}
return cmd, port, nil
}
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
if slices.Contains(req.model.Config.Capabilities, "image") {
return s.loadImageGen(req)
}
args := []string{"--mlx-engine", "--model", req.model.ShortName}
environ := []string{}
cmd, port, err := subproc(args, environ)
if err != nil {
req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
llama: &mlxrunner.Client{
Cmd: cmd,
Port: port,
},
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
for range time.Tick(20 * time.Millisecond) {
if err := func() error {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
return runner.llama.Ping(ctx)
}(); err != nil {
continue
}
break
}
return true
}
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {

View File

@@ -1,5 +1,14 @@
package tokenizer
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/ollama/ollama/types/model"
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
@@ -15,3 +24,287 @@ type Tokenizer interface {
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
func New(root *model.Root) (Tokenizer, error) {
f, err := root.Open("tokenizer.json")
if err != nil {
return nil, err
}
defer f.Close()
var tokenizer struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"`
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
return nil, err
}
special := make(map[int32]struct{})
for _, token := range tokenizer.AddedTokens {
tokenizer.Model.Vocab[token.Content] = token.ID
special[token.ID] = struct{}{}
}
vocab, err := specialTokens(root, tokenizer.Model.Vocab)
if err != nil {
return nil, err
}
vocab.Values = make([]string, len(tokenizer.Model.Vocab))
vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
for content, id := range tokenizer.Model.Vocab {
vocab.Values[id] = content
vocab.Scores[id] = float32(id)
vocab.Types[id] = TOKEN_TYPE_NORMAL
if _, ok := special[id]; ok {
vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
}
}
if tokenizer.Model.Merges != nil {
var pairs [][]string
if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
vocab.Merges = make([]string, len(pairs))
for i, pair := range pairs {
vocab.Merges[i] = pair[0] + " " + pair[1]
}
} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
return nil, err
}
}
vocab.valuesOnce.Do(func() {})
vocab.values = tokenizer.Model.Vocab
if tokenizer.Model.Type == "WordPiece" {
return NewWordPiece(vocab, true), nil
}
if tokenizer.Decoder != nil {
var decoder struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"string"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
return nil, err
}
if decoder.Type == "Sequence" {
for _, d := range decoder.Decoders {
if d.Type == "Replace" && d.Pattern.String == "▁" {
return NewSentencePiece(vocab), nil
}
}
}
}
var pretokenizers []string
if tokenizer.PreTokenizer != nil {
var pretokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string
} `json:"pattern"`
IndividualDigits bool `json:"individual_digits"`
}
}
if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
return nil, err
}
if pretokenizer.Type == "Sequence" {
for _, pretokenizer := range pretokenizer.Pretokenizers {
switch pretokenizer.Type {
case "Digits":
if pretokenizer.IndividualDigits {
pretokenizers = append(pretokenizers, `\d`)
} else {
pretokenizers = append(pretokenizers, `\d+`)
}
case "Punctuation":
pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
case "Split":
pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
case "WhitespaceSplit":
pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
}
}
}
}
return NewBytePairEncoding(vocab, pretokenizers...), nil
}
// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
var s []E
if err := json.Unmarshal(data, &s); err != nil {
var e E
if err := json.Unmarshal(data, &e); err != nil {
return err
}
s = []E{e}
}
*m = valueOrValues[E](s)
return nil
}
type specialTokenIDs struct {
BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}
// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return err
}
if content, ok := m["content"].(string); ok {
s = content
}
}
*t = stringOrContent(s)
return nil
}
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
var vocab Vocabulary
for _, c := range []struct {
name string
fn func(io.Reader) error
}{
{
name: "generation_config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
vocab.BOS = c.BOSTokenID
vocab.EOS = c.EOSTokenID
return nil
},
},
{
name: "config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 {
vocab.BOS = c.BOSTokenID
}
if len(vocab.EOS) == 0 {
vocab.EOS = c.EOSTokenID
}
return nil
},
},
{
name: "tokenizer_config.json",
fn: func(r io.Reader) error {
var c struct {
BOSToken stringOrContent `json:"bos_token"`
EOSToken stringOrContent `json:"eos_token"`
PADToken stringOrContent `json:"pad_token"`
AddBOSToken bool `json:"add_bos_token"`
AddEOSToken bool `json:"add_eos_token"`
}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 && c.BOSToken != "" {
if id, ok := values[string(c.BOSToken)]; ok {
vocab.BOS = []int32{id}
}
}
if len(vocab.EOS) == 0 && c.EOSToken != "" {
if id, ok := values[string(c.EOSToken)]; ok {
vocab.EOS = []int32{id}
}
}
vocab.AddBOS = c.AddBOSToken
vocab.AddEOS = c.AddEOSToken
return nil
},
},
{
name: "special_tokens_map.json",
fn: func(r io.Reader) error {
var c map[string]stringOrContent
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
if id, ok := values[string(bos)]; ok {
vocab.BOS = []int32{id}
}
}
if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
if id, ok := values[string(eos)]; ok {
vocab.EOS = []int32{id}
}
}
return nil
},
},
} {
if err := func() error {
f, err := root.Open(c.name)
if errors.Is(err, os.ErrNotExist) {
return nil
} else if err != nil {
return err
}
defer f.Close()
return c.fn(f)
}(); err != nil {
return nil, err
}
}
return &vocab, nil
}

309
types/model/file.go Normal file
View File

@@ -0,0 +1,309 @@
package model
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"maps"
"mime"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
func root() (*os.Root, error) {
root, err := os.OpenRoot(envconfig.Models())
if err != nil {
return nil, err
}
for _, sub := range []string{"manifests", "blobs"} {
if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
if err := root.MkdirAll(sub, 0o750); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
}
return root, nil
}
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
f, err := r.Open(filepath.Join("manifests", n.Filepath()))
if err != nil {
return nil, err
}
defer f.Close()
var m manifest
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
blobs := make(map[string]*blob, len(m.Layers)+1)
blobs[NamePrefix] = m.Config
for _, layer := range m.Layers {
if layer.Name == "" && layer.MediaType != "" {
mediatype, _, err := mime.ParseMediaType(layer.MediaType)
if err != nil {
return nil, err
}
if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
layer.Name = NamePrefix + suffix
}
}
blobs[layer.Name] = layer
}
return &Root{
root: r,
name: n,
blobs: blobs,
flags: os.O_RDONLY,
}, nil
}
// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
return &Root{
root: r,
name: n,
blobs: make(map[string]*blob),
flags: os.O_RDWR,
}, nil
}
type blob struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Name string `json:"name,omitempty"`
Size int64 `json:"size"`
// tempfile is the temporary file where the blob data is written.
tempfile *os.File
// hash is the hash.Hash used to compute the blob digest.
hash hash.Hash
}
func (b *blob) Write(p []byte) (int, error) {
return io.MultiWriter(b.tempfile, b.hash).Write(p)
}
func (b *blob) Filepath() string {
return strings.ReplaceAll(b.Digest, ":", "-")
}
type manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *blob `json:"config"`
Layers []*blob `json:"layers"`
}
// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
root *os.Root
name Name
blobs map[string]*blob
flags int
}
const MediaTypePrefix = "application/vnd.ollama"
// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are idenfitied by their media types:
//
// - name: NamePrefix + suffix
// - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
// - name: "./..image.model"
// - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."
// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
if b, ok := r.blobs[name]; ok {
r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
if err != nil {
return nil, err
}
return r, nil
}
return nil, fs.ErrNotExist
}
func (r Root) ReadFile(name string) ([]byte, error) {
f, err := r.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
if r.flags&os.O_RDWR != 0 {
w, err := os.CreateTemp(r.root.Name(), "")
if err != nil {
return nil, err
}
r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
return r.blobs[name], nil
}
return nil, fs.ErrInvalid
}
// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
if r.flags&os.O_RDWR != 0 {
for _, b := range r.blobs {
if b.tempfile != nil {
fi, err := b.tempfile.Stat()
if err != nil {
return err
}
if err := b.tempfile.Close(); err != nil {
return err
}
b.Size = fi.Size()
b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
if b.Name == NamePrefix {
b.MediaType = "application/vnd.docker.container.image.v1+json"
} else {
b.MediaType = MediaTypePrefix + suffix
}
b.Name = ""
}
rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
if err != nil {
return err
}
if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
return err
}
}
}
p := filepath.Join("manifests", r.name.Filepath())
if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
return err
}
} else if err != nil {
return err
}
f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
if err != nil {
return err
}
defer f.Close()
if err := json.NewEncoder(f).Encode(manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: r.blobs[NamePrefix],
Layers: func() []*blob {
blobs := make([]*blob, 0, len(r.blobs))
for name, b := range r.blobs {
if name != NamePrefix {
blobs = append(blobs, b)
}
}
return blobs
}(),
}); err != nil {
return err
}
}
return r.root.Close()
}
// Name returns the name of the file.
func (r Root) Name() Name {
return r.name
}
// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
return maps.Keys(r.blobs)
}
// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
if _, err := filepath.Match(pattern, ""); err != nil {
return nil, err
}
return func(yield func(string) bool) {
for name, blob := range r.blobs {
if matched, _ := filepath.Match(pattern, name); matched {
if !yield(blob.Filepath()) {
return
}
}
}
}, nil
}
func (r Root) JoinPath(parts ...string) string {
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}

90
types/model/file_test.go Normal file
View File

@@ -0,0 +1,90 @@
package model
import (
"io"
"strings"
"testing"
)
// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
for m, s := range models {
f, err := Create(m)
if err != nil {
t.Fatal(err)
}
for n, r := range s {
w, err := f.Create(n)
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(w, r); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
}
}
func TestOpen(t *testing.T) {
setup(t, map[Name]map[string]io.Reader{
ParseName("namespace/model"): {
"./.": strings.NewReader(`{"key":"value"}`),
},
ParseName("namespace/model:8b"): {
"./.": strings.NewReader(`{"foo":"bar"}`),
},
ParseName("another/model"): {
"./.": strings.NewReader(`{"another":"config"}`),
},
})
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"./."} {
r, err := f.Open(name)
if err != nil {
t.Fatal(err)
}
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
if err := r.Close(); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
t.Run("does not exist", func(t *testing.T) {
if _, err := Open(ParseName("namespace/unknown")); err == nil {
t.Error("expected error for unknown model")
}
})
t.Run("write", func(t *testing.T) {
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.Create("new-blob"); err == nil {
t.Error("expected error creating blob in read-only mode")
}
})
}

33
types/model/files.go Normal file
View File

@@ -0,0 +1,33 @@
package model
import (
"io/fs"
"iter"
"path/filepath"
)
func All() (iter.Seq[Name], error) {
r, err := root()
if err != nil {
return nil, err
}
manifests, err := r.OpenRoot("manifests")
if err != nil {
return nil, err
}
matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
if err != nil {
return nil, err
}
return func(yield func(Name) bool) {
for _, match := range matches {
name := ParseNameFromFilepath(filepath.ToSlash(match))
if !yield(name) {
return
}
}
}, nil
}

View File

@@ -227,6 +227,17 @@ func (n Name) String() string {
return b.String()
}
// Set implements [flag.Value]. It parses the provided input as a name string
// and sets the receiver to the parsed value. If the parsed name is not valid,
// ErrUnqualifiedName is returned.
func (n *Name) Set(s string) error {
*n = ParseName(s)
if !n.IsValid() {
return ErrUnqualifiedName
}
return nil
}
// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
var sb strings.Builder

94
x/mlxrunner/cache.go Normal file
View File

@@ -0,0 +1,94 @@
package mlxrunner
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
}
}

196
x/mlxrunner/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,196 @@
package cache
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Offset() int
Len() int
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}}
}
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if needed
if !c.keys.Valid() || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys.Valid() {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
}
c.offset += L
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset == c.keys.Dim(2) {
return c.keys, c.values
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
return n
}
func (c *KVCache) Clone() Cache {
return &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
maxSize int
idx int
*KVCache
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if !c.keys.Valid() {
c.keys, c.values = keys, values
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
}
// Trim to max_size to maintain sliding window
if trim := c.idx - c.maxSize + 1; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, keys))
c.values.Set(c.values.Concatenate(2, values))
c.idx = c.keys.Dim(2)
}
c.offset += keys.Dim(2)
c.idx = c.keys.Dim(2)
return c.keys, c.values
}
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if not yet at max
if !c.keys.Valid() || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys.Valid() {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
c.idx = prev
}
// Trim to max_size to maintain sliding window
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
c.idx = c.maxSize
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.offset += L
c.idx += L
validLen := min(c.offset, c.maxSize)
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset < c.keys.Dim(2) {
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
return c.keys, c.values
}
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
c.idx -= n
return n
}
func (c *RotatingKVCache) Clone() Cache {
return &RotatingKVCache{
maxSize: c.maxSize,
idx: c.idx,
KVCache: c.KVCache.Clone().(*KVCache),
}
}
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

174
x/mlxrunner/client.go Normal file
View File

@@ -0,0 +1,174 @@
package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"math"
"net"
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
type Client struct {
Port int
*exec.Cmd
}
func (c *Client) JoinPath(path string) string {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
}).JoinPath(path).String()
}
func (c *Client) CheckError(w *http.Response) error {
if w.StatusCode >= 400 {
return errors.New(w.Status)
}
return nil
}
// Close implements llm.LlamaServer.
func (c *Client) Close() error {
return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(req); err != nil {
return err
}
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
if err != nil {
return err
}
defer w.Body.Close()
if err := c.CheckError(w); err != nil {
return err
}
scanner := bufio.NewScanner(w.Body)
for scanner.Scan() {
bts := scanner.Bytes()
var resp llm.CompletionResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
fn(resp)
}
return nil
}
func (c *Client) ContextLength() int {
return math.MaxInt
}
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
panic("unimplemented")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
panic("unimplemented")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
panic("unimplemented")
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.Port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
panic("unimplemented")
}
// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
if err != nil {
return nil, err
}
defer w.Body.Close()
return []ml.DeviceID{}, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
panic("unimplemented")
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
panic("unimplemented")
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
w, err := http.Get(c.JoinPath("/v1/status"))
if err != nil {
return err
}
defer w.Body.Close()
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
if err != nil {
return nil, err
}
defer w.Body.Close()
var tokens []int
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
return nil, err
}
return tokens, nil
}
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
panic("unimplemented")
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
panic("unimplemented")
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
panic("unimplemented")
}
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
panic("unimplemented")
}
var _ llm.LlamaServer = (*Client)(nil)

3
x/mlxrunner/mlx/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
_deps
build
dist

View File

@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.5)
project(mlx)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.0" CACHE STRING "")
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

21
x/mlxrunner/mlx/act.go Normal file
View File

@@ -0,0 +1,21 @@
package mlx
// #include "generated.h"
import "C"
import "math"
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
t.Add(
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(t.DType())
}
func SILU(t *Array) *Array {
return t.Multiply(t.Sigmoid()).AsType(t.DType())
}

264
x/mlxrunner/mlx/array.go Normal file
View File

@@ -0,0 +1,264 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"log/slog"
"reflect"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
}
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
return t
}
type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
case int:
tt.ctx = C.mlx_array_new_int(C.int(v))
case float32:
tt.ctx = C.mlx_array_new_float32(C.float(v))
case float64:
tt.ctx = C.mlx_array_new_float64(C.double(v))
case complex64:
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
default:
panic("unsupported type")
}
return tt
}
type arrayTypes interface {
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~int8 | ~int16 | ~int32 | ~int64 |
~float32 | ~float64 |
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
cShape := make([]C.int, len(shape))
for i := range shape {
cShape[i] = C.int(shape[i])
}
var dtype DType
switch reflect.TypeOf(s).Elem().Kind() {
case reflect.Bool:
dtype = DTypeBool
case reflect.Uint8:
dtype = DTypeUint8
case reflect.Uint16:
dtype = DTypeUint16
case reflect.Uint32:
dtype = DTypeUint32
case reflect.Uint64:
dtype = DTypeUint64
case reflect.Int8:
dtype = DTypeInt8
case reflect.Int16:
dtype = DTypeInt16
case reflect.Int32:
dtype = DTypeInt32
case reflect.Int64:
dtype = DTypeInt64
case reflect.Float32:
dtype = DTypeFloat32
case reflect.Float64:
dtype = DTypeFloat64
case reflect.Complex64:
dtype = DTypeComplex64
default:
panic("unsupported type")
}
bts := make([]byte, binary.Size(s))
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
panic(err)
}
tt := New("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
func (t *Array) Set(other *Array) {
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// misc. utilities
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
slog.Any("shape", t.Dims()),
slog.Int("num_bytes", t.NumBytes()),
)
}
return slog.GroupValue(attrs...)
}
// shape utilities
func (t Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
}
return dims
}
func (t Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
func (t Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--
if t.desc.numRefs <= 0 {
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
}

View File

@@ -0,0 +1,43 @@
package mlx
import "testing"
func TestFromValue(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}
func TestFromValues(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}

94
x/mlxrunner/mlx/dtype.go Normal file
View File

@@ -0,0 +1,94 @@
package mlx
// #include "generated.h"
import "C"
type DType int
func (t DType) String() string {
switch t {
case DTypeBool:
return "BOOL"
case DTypeUint8:
return "U8"
case DTypeUint16:
return "U16"
case DTypeUint32:
return "U32"
case DTypeUint64:
return "U64"
case DTypeInt8:
return "I8"
case DTypeInt16:
return "I16"
case DTypeInt32:
return "I32"
case DTypeInt64:
return "I64"
case DTypeFloat16:
return "F16"
case DTypeFloat32:
return "F32"
case DTypeFloat64:
return "F64"
case DTypeBFloat16:
return "BF16"
case DTypeComplex64:
return "C64"
default:
return "Unknown"
}
}
func (t *DType) UnmarshalJSON(b []byte) error {
switch string(b) {
case `"BOOL"`:
*t = DTypeBool
case `"U8"`:
*t = DTypeUint8
case `"U16"`:
*t = DTypeUint16
case `"U32"`:
*t = DTypeUint32
case `"U64"`:
*t = DTypeUint64
case `"I8"`:
*t = DTypeInt8
case `"I16"`:
*t = DTypeInt16
case `"I32"`:
*t = DTypeInt32
case `"I64"`:
*t = DTypeInt64
case `"F16"`:
*t = DTypeFloat16
case `"F64"`:
*t = DTypeFloat64
case `"F32"`:
*t = DTypeFloat32
case `"BF16"`:
*t = DTypeBFloat16
case `"C64"`:
*t = DTypeComplex64
default:
return nil
}
return nil
}
const (
DTypeBool DType = C.MLX_BOOL
DTypeUint8 DType = C.MLX_UINT8
DTypeUint16 DType = C.MLX_UINT16
DTypeUint32 DType = C.MLX_UINT32
DTypeUint64 DType = C.MLX_UINT64
DTypeInt8 DType = C.MLX_INT8
DTypeInt16 DType = C.MLX_INT16
DTypeInt32 DType = C.MLX_INT32
DTypeInt64 DType = C.MLX_INT64
DTypeFloat16 DType = C.MLX_FLOAT16
DTypeFloat32 DType = C.MLX_FLOAT32
DTypeFloat64 DType = C.MLX_FLOAT64
DTypeBFloat16 DType = C.MLX_BFLOAT16
DTypeComplex64 DType = C.MLX_COMPLEX64
)

34
x/mlxrunner/mlx/dynamic.c Normal file
View File

@@ -0,0 +1,34 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
CHECK(handle->ctx != NULL);
return 0;
}
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
return mlx_dynamic_open(handle, path);
}
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
if (handle->ctx) {
DLCLOSE(handle->ctx);
handle->ctx = NULL;
}
}

View File

@@ -0,0 +1,63 @@
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"unsafe"
)
func init() {
switch runtime.GOOS {
case "darwin":
case "windows":
default:
return
}
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
return
}
for _, path := range filepath.SplitList(paths) {
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
if err != nil {
panic(err)
}
for _, match := range matches {
path := filepath.Join(paths, match)
slog.Info("Loading MLX dynamic library", "path", path)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
slog.Info("Loaded MLX dynamic library", "path", path)
return
}
}
panic("Failed to load any MLX dynamic library")
}

27
x/mlxrunner/mlx/dynamic.h Normal file
View File

@@ -0,0 +1,27 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

72
x/mlxrunner/mlx/fast.go Normal file
View File

@@ -0,0 +1,72 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
if mask == nil {
mask = New("")
}
sinks := New("")
mode := "causal"
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RoPE struct {
Dims int
Traditional bool
Base float32 `json:"rope_theta"`
Scale float32
}
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
C.mlx_fast_rope(
&out.ctx,
t.ctx,
C.int(r.Dims),
C._Bool(r.Traditional),
C.mlx_optional_float{
value: C.float(r.Base),
has_value: C._Bool(func() bool { return r.Base != 0 }()),
},
C.float(r.Scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}

6568
x/mlxrunner/mlx/generated.c Normal file
View File

File diff suppressed because it is too large Load Diff

4872
x/mlxrunner/mlx/generated.h Normal file
View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
CHECK_LOAD(handle, {{ .Name }});
{{- end }}
return 0;
}
{{- range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}

View File

@@ -0,0 +1,20 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }};
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -0,0 +1,135 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
type Function struct {
Type,
Name,
Parameters,
Args string
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
var fn Function
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
if params := node.ChildByFieldName("parameters"); params != nil {
fn.Parameters = params.Utf8Text(source)
fn.Args = ParseParameters(params, tc, source)
}
var types []string
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
if node.Parent().Kind() == "pointer_declarator" {
types = append(types, "*")
}
node = node.Parent()
}
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
types = append(types, sibling.Utf8Text(source))
}
slices.Reverse(types)
fn.Type = strings.Join(types, " ")
return fn
}
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
var s []string
for _, child := range node.Children(tc) {
if child.IsNamed() {
child := child.ChildByFieldName("declarator")
for child != nil && child.Kind() != "identifier" {
if child.Kind() == "parenthesized_declarator" {
child = child.Child(1)
} else {
child = child.ChildByFieldName("declarator")
}
}
if child != nil {
s = append(s, child.Utf8Text(source))
}
}
}
return strings.Join(s, ", ")
}
func main() {
var output string
flag.StringVar(&output, "output", ".", "Output directory for generated files")
flag.Parse()
parser := tree_sitter.NewParser()
defer parser.Close()
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
parser.SetLanguage(language)
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
defer query.Close()
qc := tree_sitter.NewQueryCursor()
defer qc.Close()
var funs []Function
for _, arg := range flag.Args() {
bts, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
continue
}
tree := parser.Parse(bts, nil)
defer tree.Close()
tc := tree.Walk()
defer tc.Close()
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
funs = append(funs, ParseFunction(&capture.Node, tc, bts))
}
}
}
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
return
}
for _, tmpl := range tmpl.Templates() {
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
fmt.Println("Generating", name)
f, err := os.Create(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
continue
}
defer f.Close()
if err := tmpl.Execute(f, map[string]any{
"Functions": funs,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
}
}
}

91
x/mlxrunner/mlx/io.go Normal file
View File

@@ -0,0 +1,91 @@
package mlx
// #include "generated.h"
import "C"
import (
"iter"
"log/slog"
"maps"
"slices"
"unsafe"
"github.com/ollama/ollama/types/model"
)
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
string2string := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(string2string)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cpu := C.mlx_default_cpu_stream_new()
defer C.mlx_stream_free(cpu)
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
break
}
}
}
}
func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) error) error {
matches, err := root.Glob(pattern)
if err != nil {
return err
}
weights := make(map[string]*Array)
for match := range matches {
slog.Debug("Loading weights from", "file", match)
maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
}
var numBytes int
for name, weight := range states {
if _, ok := weights[name]; ok {
slog.Debug("Loading weight", "name", name, "weight", weight)
*weight = *weights[name]
numBytes += weight.NumBytes()
}
}
for _, afterLoadFunc := range afterLoadFuncs {
if err := afterLoadFunc(root); err != nil {
return err
}
}
Eval(slices.Collect(maps.Values(states))...)
ClearCache()
slog.Info("Loaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
return nil
}
func UnloadAll(states map[string]*Array) {
weights := slices.Collect(maps.Values(states))
for _, weight := range weights {
weight.desc.numRefs = 0
}
numBytes := Free(weights...)
slog.Info("Unloaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
}

85
x/mlxrunner/mlx/memory.go Normal file
View File

@@ -0,0 +1,85 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
func (b Byte) String() string {
return strconv.FormatInt(int64(b), 10) + " B"
}
func (b KibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}
func (b MebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}
func (b GibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}
func (b TebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}
func PrettyBytes(n int) fmt.Stringer {
switch {
case n < 1<<10:
return Byte(n)
case n < 1<<(2*10):
return KibiByte(n)
case n < 1<<(3*10):
return MebiByte(n)
case n < 1<<(4*10):
return GibiByte(n)
default:
return TebiByte(n)
}
}
func ActiveMemory() int {
var active C.size_t
C.mlx_get_active_memory(&active)
return int(active)
}
func CacheMemory() int {
var cache C.size_t
C.mlx_get_cache_memory(&cache)
return int(cache)
}
func PeakMemory() int {
var peak C.size_t
C.mlx_get_peak_memory(&peak)
return int(peak)
}
type Memory struct{}
func (Memory) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("active", PrettyBytes(ActiveMemory())),
slog.Any("cache", PrettyBytes(CacheMemory())),
slog.Any("peak", PrettyBytes(PeakMemory())),
)
}
type (
Byte int
KibiByte int
MebiByte int
GibiByte int
TebiByte int
)
func ClearCache() {
C.mlx_clear_cache()
}

43
x/mlxrunner/mlx/mlx.go Normal file
View File

@@ -0,0 +1,43 @@
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"
import (
"unsafe"
)
func doEval(outputs []*Array, async bool) {
vectorData := make([]C.mlx_array, 0, len(outputs))
for _, output := range outputs {
if output.Valid() {
vectorData = append(vectorData, output.ctx)
}
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
if async {
C.mlx_async_eval(vector)
} else {
C.mlx_eval(vector)
}
}
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Array) {
doEval(outputs, false)
}

30
x/mlxrunner/mlx/nn.go Normal file
View File

@@ -0,0 +1,30 @@
package mlx
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
type Embedding struct {
Weight Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() Linear {
return Linear{
Weight: e.Weight,
}
}

192
x/mlxrunner/mlx/ops.go Normal file
View File

@@ -0,0 +1,192 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
unsafe.SliceData(cStrides), C.size_t(len(strides)),
C.size_t(offset),
DefaultStream().ctx,
)
return out
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("CONCATENATE", t)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])
}
t := New("ZEROS")
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
return t
}

11
x/mlxrunner/mlx/random.go Normal file
View File

@@ -0,0 +1,11 @@
package mlx
// #include "generated.h"
import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

84
x/mlxrunner/mlx/slice.go Normal file
View File

@@ -0,0 +1,84 @@
package mlx
// #include "generated.h"
import "C"
import (
"cmp"
"unsafe"
)
type slice struct {
args []int
}
func Slice(args ...int) slice {
return slice{args: args}
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
}
args := [3][]C.int{
make([]C.int, len(slices)),
make([]C.int, len(slices)),
make([]C.int, len(slices)),
}
for i, s := range slices {
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dims[i])
args[2][i] = C.int(1)
case 1:
// slice[i]
args[0][i] = C.int(s.args[0])
args[1][i] = C.int(s.args[0] + 1)
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")
}
}
return args[0], args[1], args[2]
}
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}

43
x/mlxrunner/mlx/stream.go Normal file
View File

@@ -0,0 +1,43 @@
package mlx
// #include "generated.h"
import "C"
import (
"log/slog"
"sync"
)
type Device struct {
ctx C.mlx_device
}
func (d Device) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_device_tostring(&str, d.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultDevice = sync.OnceValue(func() Device {
d := C.mlx_device_new()
C.mlx_get_default_device(&d)
return Device{d}
})
type Stream struct {
ctx C.mlx_stream
}
func (s Stream) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_stream_tostring(&str, s.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultStream = sync.OnceValue(func() Stream {
s := C.mlx_stream_new()
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
return Stream{s}
})

View File

@@ -0,0 +1,8 @@
package base
import "github.com/ollama/ollama/x/mlxrunner/cache"
// Cacher is implemented by models that support custom caching mechanisms.
type Cacher interface {
Cache() []cache.Cache
}

View File

@@ -0,0 +1,113 @@
package base
import (
"encoding/json"
"errors"
"log/slog"
"reflect"
"strconv"
"strings"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Model interface {
// Forward performs a forward pass through the model.
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
// NumLayers returns the number of layers in the model.
// This is used to initialize caches.
// TODO: consider moving cache initialization into the model itself.
NumLayers() int
}
type TextGeneration interface {
Model
Unembed(*mlx.Array) *mlx.Array
}
func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) error) {
mapping := make(map[string]*mlx.Array)
var afterLoadFuncs []func(*model.Root) error
var fn func(v reflect.Value, tags []string)
fn = func(v reflect.Value, tags []string) {
t := v.Type()
if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
var afterLoadFunc func(*model.Root) error
reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
}
for _, field := range reflect.VisibleFields(t) {
if field.IsExported() {
tt, vv := field.Type, v.FieldByIndex(field.Index)
// create local copy so tags are not modified between fields
tags := tags
if tag := field.Tag.Get("weight"); tag != "" {
// TODO: use model.Tag
tags = append(tags, tag)
}
if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
name := strings.Join(tags, ".")
mapping[name] = vv.Addr().Interface().(*mlx.Array)
continue
}
switch tt.Kind() {
case reflect.Interface:
vv = vv.Elem()
fallthrough
case reflect.Pointer:
vv = vv.Elem()
fallthrough
case reflect.Struct:
fn(vv, tags)
case reflect.Slice, reflect.Array:
for i := range vv.Len() {
fn(vv.Index(i), append(tags, strconv.Itoa(i)))
}
}
}
}
}
fn(reflect.ValueOf(m).Elem(), []string{})
return mapping, afterLoadFuncs
}
var m = make(map[string]func(*model.Root) (Model, error))
func Register(name string, f func(*model.Root) (Model, error)) {
if _, exists := m[name]; exists {
panic("model already registered: " + name)
}
m[name] = f
}
func New(root *model.Root) (Model, error) {
c, err := root.Open("config.json")
if err != nil {
return nil, err
}
defer c.Close()
var config struct {
Architectures []string `json:"architectures"`
}
if err := json.NewDecoder(c).Decode(&config); err != nil {
return nil, err
}
slog.Info("Model architecture", "arch", config.Architectures[0])
if f, exists := m[config.Architectures[0]]; exists {
return f(root)
}
return nil, errors.New("unknown architecture")
}

View File

@@ -0,0 +1,84 @@
package gemma
import (
"cmp"
"encoding/json"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
type Model struct {
Text TextModel `weight:"language_model"`
}
func (m *Model) NumLayers() int {
return len(m.Text.Layers)
}
func (m Model) Cache() []cache.Cache {
caches := make([]cache.Cache, m.NumLayers())
for i := range caches {
if (i+1)%m.Text.Options.SlidingWindowPattern == 0 {
caches[i] = cache.NewKVCache()
} else {
caches[i] = cache.NewRotatingKVCache(m.Text.Options.SlidingWindow)
}
}
return caches
}
func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
return m.Text.Forward(inputs, cache)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.Text.EmbedTokens.AsLinear().Forward(x)
}
func init() {
base.Register("Gemma3ForConditionalGeneration", func(root *model.Root) (base.Model, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts struct {
Text TextOptions `json:"text_config"`
}
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
opts.Text.NumAttentionHeads = cmp.Or(opts.Text.NumAttentionHeads, 8)
opts.Text.NumKeyValueHeads = cmp.Or(opts.Text.NumKeyValueHeads, 4)
opts.Text.HeadDim = cmp.Or(opts.Text.HeadDim, 256)
opts.Text.RMSNormEps = cmp.Or(opts.Text.RMSNormEps, 1e-6)
opts.Text.SlidingWindowPattern = cmp.Or(opts.Text.SlidingWindowPattern, 6)
// TODO: implement json.Unmarshaler
opts.Text.RoPE = map[bool]mlx.RoPE{
true: {Dims: opts.Text.HeadDim, Traditional: false, Base: 1_000_000, Scale: 1. / 8.},
false: {Dims: opts.Text.HeadDim, Traditional: false, Base: 10_000, Scale: 1},
}
return &Model{
Text: TextModel{
Layers: make([]TextDecoderLayer, opts.Text.NumHiddenLayers),
Options: opts.Text,
},
}, nil
})
}
type RMSNorm struct {
mlx.RMSNorm
}
func (m *RMSNorm) AfterLoad(*model.Root) error {
m.Weight.Set(m.Weight.Add(mlx.FromValue(1)))
return nil
}

View File

@@ -0,0 +1,118 @@
package gemma
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type TextOptions struct {
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
HeadDim int `json:"head_dim"`
RMSNormEps float32 `json:"rms_norm_eps"`
SlidingWindow int `json:"sliding_window"`
SlidingWindowPattern int `json:"sliding_window_pattern"`
RoPE map[bool]mlx.RoPE
}
type TextModel struct {
EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
Layers []TextDecoderLayer `weight:"model.layers"`
Norm RMSNorm `weight:"model.norm"`
Options TextOptions
}
func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
hiddenSize := mlx.FromValue(m.Options.HiddenSize).AsType(hiddenStates.DType())
hiddenStates = hiddenStates.Multiply(hiddenSize.Sqrt())
for i, layer := range m.Layers {
hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options.RoPE[(i+1)%m.Options.SlidingWindowPattern == 0], m.Options)
}
hiddenStates = m.Norm.Forward(hiddenStates, m.Options.RMSNormEps)
return hiddenStates
}
type TextDecoderLayer struct {
InputNorm RMSNorm `weight:"input_layernorm"`
Attention TextAttention `weight:"self_attn"`
PostAttnNorm RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm RMSNorm `weight:"pre_feedforward_layernorm"`
MLP TextMLP `weight:"mlp"`
PostFFNorm RMSNorm `weight:"post_feedforward_layernorm"`
}
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
residual := hiddenStates
hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
hiddenStates = m.PostAttnNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = hiddenStates.Add(residual)
residual = hiddenStates
hiddenStates = m.PreFFNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.MLP.Forward(hiddenStates, opts)
hiddenStates = m.PostFFNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = hiddenStates.Add(residual)
return hiddenStates
}
type TextAttention struct {
QProj mlx.Linear `weight:"q_proj"`
QNorm RMSNorm `weight:"q_norm"`
KProj mlx.Linear `weight:"k_proj"`
KNorm RMSNorm `weight:"k_norm"`
VProj mlx.Linear `weight:"v_proj"`
OProj mlx.Linear `weight:"o_proj"`
}
func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
query := m.QProj.Forward(hiddenStates)
key := m.KProj.Forward(hiddenStates)
value := m.VProj.Forward(hiddenStates)
query = query.AsStrided(
[]int{B, opts.NumAttentionHeads, L, opts.HeadDim},
[]int{L * opts.NumAttentionHeads * opts.HeadDim, opts.HeadDim, opts.NumAttentionHeads * opts.HeadDim, 1},
0)
key = key.AsStrided(
[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
0)
value = value.AsStrided(
[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
0)
query = m.QNorm.Forward(query, opts.RMSNormEps)
key = m.KNorm.Forward(key, opts.RMSNormEps)
query = rope.Forward(query, cache.Offset())
key = rope.Forward(key, cache.Offset())
key, value = cache.Update(key, value)
attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(opts.HeadDim))))
attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
return m.OProj.Forward(attention)
}
type TextMLP struct {
GateProj mlx.Linear `weight:"gate_proj"`
UpProj mlx.Linear `weight:"up_proj"`
DownProj mlx.Linear `weight:"down_proj"`
}
func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}

View File

@@ -0,0 +1,130 @@
package llama
import (
"encoding/json"
"log/slog"
"math"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
type Options struct {
HiddenAct string `json:"hidden_act"`
HiddenSize int `json:"hidden_size"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumHiddenLayers int `json:"num_hidden_layers"`
NumKeyValueHeads int `json:"num_key_value_heads"`
RMSNormEps float32 `json:"rms_norm_eps"`
mlx.RoPE
}
type Model struct {
EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
Layers []Layer `weight:"model.layers"`
Norm mlx.RMSNorm `weight:"model.norm"`
Output mlx.Linear `weight:"lm_head"`
Options
}
func (m Model) NumLayers() int {
return len(m.Layers)
}
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
for i, layer := range m.Layers {
hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options)
}
hiddenStates = m.Norm.Forward(hiddenStates, m.RMSNormEps)
hiddenStates = m.Output.Forward(hiddenStates)
slog.Debug("Model.forward", "output shape", hiddenStates.Dims(), "m.Output", m.Output.Weight.Dims())
return hiddenStates
}
type Layer struct {
AttentionNorm mlx.RMSNorm `weight:"input_layernorm"`
Attention Attention `weight:"self_attn"`
MLPNorm mlx.RMSNorm `weight:"post_attention_layernorm"`
MLP MLP `weight:"mlp"`
}
func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
hiddenStates = hiddenStates.Add(residual)
residual = hiddenStates
hiddenStates = m.MLPNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.MLP.Forward(hiddenStates)
hiddenStates = hiddenStates.Add(residual)
return hiddenStates
}
type Attention struct {
QueryProj mlx.Linear `weight:"q_proj"`
KeyProj mlx.Linear `weight:"k_proj"`
ValueProj mlx.Linear `weight:"v_proj"`
OutputProj mlx.Linear `weight:"o_proj"`
}
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
query := m.QueryProj.Forward(hiddenStates)
query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)
key := m.KeyProj.Forward(hiddenStates)
key = key.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)
value := m.ValueProj.Forward(hiddenStates)
value = value.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)
query = opts.RoPE.Forward(query, cache.Offset())
key = opts.RoPE.Forward(key, cache.Offset())
key, value = cache.Update(key, value)
attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(key.Dim(-1)))))
attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
return m.OutputProj.Forward(attention)
}
type MLP struct {
Gate mlx.Linear `weight:"gate_proj"`
Up mlx.Linear `weight:"up_proj"`
Down mlx.Linear `weight:"down_proj"`
}
func (m MLP) Forward(h *mlx.Array) *mlx.Array {
return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
}
func init() {
base.Register("MistralForCausalLM", func(root *model.Root) (base.Model, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts Options
// TODO: implement json.Unmarshal for Options
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
if err := json.Unmarshal(bts, &opts.RoPE); err != nil {
return nil, err
}
return &Model{
Layers: make([]Layer, opts.NumHiddenLayers),
Options: opts,
}, nil
})
}

View File

@@ -0,0 +1,6 @@
package model
import (
_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
)

138
x/mlxrunner/pipeline.go Normal file
View File

@@ -0,0 +1,138 @@
package mlxrunner
import (
"bytes"
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
func (r *Runner) TextGenerationPipeline(request Request) error {
model, ok := r.Model.(base.TextGeneration)
if !ok {
return errors.New("model does not support causal language modeling")
}
inputs, err := r.Tokenizer.Encode(request.Prompt, true)
if err != nil {
return err
}
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacher, ok := model.(base.Cacher); ok {
caches = cacher.Cache()
} else {
caches = make([]cache.Cache, model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
// TODO: additional logit processing (logit bias, repetition penalty, etc.)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
// buffer partial, multibyte unicode
var b bytes.Buffer
now := time.Now()
final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
outputs = append(outputs, output)
if r.Tokenizer.Is(output, tokenizer.SpecialEOS) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
break
}
request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}
mlx.Free(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
return nil
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token, err := r.Tokenizer.Decode([]int32{sample})
if err != nil {
slog.Error("Failed to decode tokens", "error", err)
return ""
}
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
}

110
x/mlxrunner/runner.go Normal file
View File

@@ -0,0 +1,110 @@
package mlxrunner
import (
"context"
"log/slog"
"net"
"net/http"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
_ "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
type Request struct {
TextCompletionsRequest
Responses chan Response
Pipeline func(Request) error
sample.Sampler
caches []cache.Cache
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model base.Model
Tokenizer tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
}
func (r *Runner) Load(name model.Name) (err error) {
root, err := model.Open(name)
if err != nil {
return err
}
defer root.Close()
r.Model, err = base.New(root)
if err != nil {
return err
}
r.Tokenizer, err = tokenizer.New(root)
if err != nil {
return err
}
weights, afterLoadFuncs := base.Weights(r.Model)
return mlx.LoadAll(root, "model*.safetensors", weights, afterLoadFuncs)
}
func (r *Runner) Run(host, port string, mux http.Handler) error {
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
break
}
close(request.Responses)
}
}
})
g.Go(func() error {
slog.Info("Starting HTTP server", "host", host, "port", port)
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
})
return g.Wait()
}

View File

@@ -0,0 +1,75 @@
package sample
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Sampler interface {
Sample(*mlx.Array) *mlx.Array
}
func New(temp, top_p, min_p float32, top_k int) Sampler {
if temp == 0 {
return greedy{}
}
var samplers []Sampler
if top_p > 0 && top_p < 1 {
samplers = append(samplers, TopP(top_p))
}
if min_p != 0 {
samplers = append(samplers, MinP(min_p))
}
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers)
}
type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c {
logits = sampler.Sample(logits)
}
return logits
}
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}
type TopP float32
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type MinP float32
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type TopK int
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}

180
x/mlxrunner/server.go Normal file
View File

@@ -0,0 +1,180 @@
package mlxrunner
import (
"bytes"
"cmp"
"encoding/json"
"flag"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
var (
name model.Name
port int
)
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
flagSet.Var(&name, "model", "Model name")
flagSet.IntVar(&port, "port", 0, "Port to listen on")
_ = flagSet.Bool("verbose", false, "Enable debug logging")
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
}
if err := runner.Load(name); err != nil {
return err
}
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{
"status": 0,
"progress": 100,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "POST":
fallthrough
case "GET":
if err := json.NewEncoder(w).Encode(map[string]any{
"Success": true,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
case "DELETE":
// TODO: cleanup model and cache
}
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
)
runner.Requests <- request
w.Header().Set("Content-Type", "application/jsonl")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
for response := range request.Responses {
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
})
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
var b bytes.Buffer
if _, err := io.Copy(&b, r.Body); err != nil {
slog.Error("Failed to read request body", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
tokens, err := runner.Tokenizer.Encode(b.String(), true)
if err != nil {
slog.Error("Failed to tokenize text", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(tokens); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
for source, target := range map[string]string{
"GET /health": "/v1/status",
"POST /load": "/v1/models",
"POST /completion": "/v1/completions",
} {
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
}
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
t := time.Now()
mux.ServeHTTP(recorder, r)
var level slog.Level
switch {
case recorder.code >= 500:
level = slog.LevelError
case recorder.code >= 400:
level = slog.LevelWarn
case recorder.code >= 300:
return
}
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
}))
}
type statusRecorder struct {
http.ResponseWriter
code int
}
func (w *statusRecorder) WriteHeader(code int) {
w.code = code
w.ResponseWriter.WriteHeader(code)
}
func (w *statusRecorder) Status() string {
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}
func (w *statusRecorder) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}