Compare commits

..

3 Commits

Author SHA1 Message Date
Jeffrey Morgan
099a0f18ef build: fix Dockerfile mlx directory (#14131) 2026-02-06 17:08:53 -08:00
Richard Lyons
fff696ee31 docs: increased RAM requirement for parallelism 2026-02-06 15:49:39 -08:00
Jeffrey Morgan
2e3ce6eab3 anthropic: do not count image tokens for now (#14127) 2026-02-06 15:33:18 -08:00
50 changed files with 20 additions and 15886 deletions

View File

@@ -147,7 +147,7 @@ ARG PARALLEL
WORKDIR /go/src/github.com/ollama/ollama
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
COPY x/ml/backend/mlx x/ml/backend/mlx
COPY x/imagegen/mlx x/imagegen/mlx
COPY go.mod go.sum .
COPY MLX_VERSION .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local

View File

@@ -897,11 +897,5 @@ func countContentBlock(block any) int {
}
}
if source, ok := blockMap["source"].(map[string]any); ok {
if data, ok := source["data"].(string); ok {
total += len(data)
}
}
return total
}

View File

@@ -312,7 +312,7 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.

5
go.mod
View File

@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -29,8 +29,6 @@ require (
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/tree-sitter/go-tree-sitter v0.25.0
github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -52,7 +50,6 @@ require (
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

31
go.sum
View File

@@ -152,8 +152,6 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -208,39 +206,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@@ -4,7 +4,6 @@ import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -18,8 +17,6 @@ func Execute(args []string) error {
return ollamarunner.Execute(args[1:])
case "--imagegen-engine":
return imagegen.Execute(args[1:])
case "--mlx-engine":
return mlxrunner.Execute(args[1:])
}
}
return llamarunner.Execute(args)

View File

@@ -5,13 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"os"
"os/exec"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -26,7 +22,6 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -200,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if s.loadSafetensors(pending) {
// Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
if s.loadMLX(pending) {
break
}
continue
}
// Check for experimental safetensors LLM models
if pending.model.Config.ModelFormat == "safetensors" {
if slices.Contains(pending.model.Config.Capabilities, "completion") {
// LLM model with safetensors format - use MLX runner
if s.loadMLX(pending) {
break
}
continue
}
}
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -557,90 +563,9 @@ iGPUScan:
return false
}
func subproc(args, environ []string) (*exec.Cmd, int, error) {
exe, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
}
for range 3 {
// get a random port in the ephemeral range
port := rand.Intn(65535-49152) + 49152
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
cmd.Env = slices.Concat(os.Environ(), environ)
cmd.Stdout = os.Stderr
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
continue
}
return cmd, port, nil
}
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
}
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
if slices.Contains(req.model.Config.Capabilities, "image") {
return s.loadImageGen(req)
}
args := []string{"--mlx-engine", "--model", req.model.ShortName}
environ := []string{}
cmd, port, err := subproc(args, environ)
if err != nil {
req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
return true
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
Options: &req.opts,
loading: false,
sessionDuration: sessionDuration,
llama: &mlxrunner.Client{
Cmd: cmd,
Port: port,
},
}
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
s.loadedMu.Unlock()
runner.refMu.Lock()
if sessionDuration > 0 {
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
s.expiredCh <- runner
})
}
runner.refMu.Unlock()
req.useLoadedRunner(runner, s.finishedReqCh)
for range time.Tick(20 * time.Millisecond) {
if err := func() error {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
return runner.llama.Ping(ctx)
}(); err != nil {
continue
}
break
}
return true
}
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
// loadMLX loads an experimental safetensors model using the unified MLX runner.
// This supports both LLM (completion) and image generation models.
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
// Determine mode based on capabilities
var mode imagegen.ModelMode
if slices.Contains(req.model.Config.Capabilities, "image") {

View File

@@ -1,14 +1,5 @@
package tokenizer
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/ollama/ollama/types/model"
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
@@ -24,287 +15,3 @@ type Tokenizer interface {
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
func New(root *model.Root) (Tokenizer, error) {
f, err := root.Open("tokenizer.json")
if err != nil {
return nil, err
}
defer f.Close()
var tokenizer struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges json.RawMessage `json:"merges"`
} `json:"model"`
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
Decoder json.RawMessage `json:"decoder"`
AddedTokens []struct {
ID int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
}
if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
return nil, err
}
special := make(map[int32]struct{})
for _, token := range tokenizer.AddedTokens {
tokenizer.Model.Vocab[token.Content] = token.ID
special[token.ID] = struct{}{}
}
vocab, err := specialTokens(root, tokenizer.Model.Vocab)
if err != nil {
return nil, err
}
vocab.Values = make([]string, len(tokenizer.Model.Vocab))
vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
for content, id := range tokenizer.Model.Vocab {
vocab.Values[id] = content
vocab.Scores[id] = float32(id)
vocab.Types[id] = TOKEN_TYPE_NORMAL
if _, ok := special[id]; ok {
vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
}
}
if tokenizer.Model.Merges != nil {
var pairs [][]string
if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
vocab.Merges = make([]string, len(pairs))
for i, pair := range pairs {
vocab.Merges[i] = pair[0] + " " + pair[1]
}
} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
return nil, err
}
}
vocab.valuesOnce.Do(func() {})
vocab.values = tokenizer.Model.Vocab
if tokenizer.Model.Type == "WordPiece" {
return NewWordPiece(vocab, true), nil
}
if tokenizer.Decoder != nil {
var decoder struct {
Type string `json:"type"`
Decoders []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"string"`
} `json:"pattern"`
} `json:"decoders"`
}
if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
return nil, err
}
if decoder.Type == "Sequence" {
for _, d := range decoder.Decoders {
if d.Type == "Replace" && d.Pattern.String == "▁" {
return NewSentencePiece(vocab), nil
}
}
}
}
var pretokenizers []string
if tokenizer.PreTokenizer != nil {
var pretokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
Regex string
} `json:"pattern"`
IndividualDigits bool `json:"individual_digits"`
}
}
if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
return nil, err
}
if pretokenizer.Type == "Sequence" {
for _, pretokenizer := range pretokenizer.Pretokenizers {
switch pretokenizer.Type {
case "Digits":
if pretokenizer.IndividualDigits {
pretokenizers = append(pretokenizers, `\d`)
} else {
pretokenizers = append(pretokenizers, `\d+`)
}
case "Punctuation":
pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
case "Split":
pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
case "WhitespaceSplit":
pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
}
}
}
}
return NewBytePairEncoding(vocab, pretokenizers...), nil
}
// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
var s []E
if err := json.Unmarshal(data, &s); err != nil {
var e E
if err := json.Unmarshal(data, &e); err != nil {
return err
}
s = []E{e}
}
*m = valueOrValues[E](s)
return nil
}
type specialTokenIDs struct {
BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}
// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return err
}
if content, ok := m["content"].(string); ok {
s = content
}
}
*t = stringOrContent(s)
return nil
}
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
var vocab Vocabulary
for _, c := range []struct {
name string
fn func(io.Reader) error
}{
{
name: "generation_config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
vocab.BOS = c.BOSTokenID
vocab.EOS = c.EOSTokenID
return nil
},
},
{
name: "config.json",
fn: func(r io.Reader) error {
var c specialTokenIDs
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 {
vocab.BOS = c.BOSTokenID
}
if len(vocab.EOS) == 0 {
vocab.EOS = c.EOSTokenID
}
return nil
},
},
{
name: "tokenizer_config.json",
fn: func(r io.Reader) error {
var c struct {
BOSToken stringOrContent `json:"bos_token"`
EOSToken stringOrContent `json:"eos_token"`
PADToken stringOrContent `json:"pad_token"`
AddBOSToken bool `json:"add_bos_token"`
AddEOSToken bool `json:"add_eos_token"`
}
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if len(vocab.BOS) == 0 && c.BOSToken != "" {
if id, ok := values[string(c.BOSToken)]; ok {
vocab.BOS = []int32{id}
}
}
if len(vocab.EOS) == 0 && c.EOSToken != "" {
if id, ok := values[string(c.EOSToken)]; ok {
vocab.EOS = []int32{id}
}
}
vocab.AddBOS = c.AddBOSToken
vocab.AddEOS = c.AddEOSToken
return nil
},
},
{
name: "special_tokens_map.json",
fn: func(r io.Reader) error {
var c map[string]stringOrContent
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
if id, ok := values[string(bos)]; ok {
vocab.BOS = []int32{id}
}
}
if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
if id, ok := values[string(eos)]; ok {
vocab.EOS = []int32{id}
}
}
return nil
},
},
} {
if err := func() error {
f, err := root.Open(c.name)
if errors.Is(err, os.ErrNotExist) {
return nil
} else if err != nil {
return err
}
defer f.Close()
return c.fn(f)
}(); err != nil {
return nil, err
}
}
return &vocab, nil
}

View File

@@ -1,316 +0,0 @@
package model
import (
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"maps"
"mime"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/envconfig"
)
func root() (*os.Root, error) {
root, err := os.OpenRoot(envconfig.Models())
if err != nil {
return nil, err
}
for _, sub := range []string{"manifests", "blobs"} {
if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
if err := root.MkdirAll(sub, 0o750); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
}
return root, nil
}
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
f, err := r.Open(filepath.Join("manifests", n.Filepath()))
if err != nil {
return nil, err
}
defer f.Close()
var m manifest
if err := json.NewDecoder(f).Decode(&m); err != nil {
return nil, err
}
blobs := make(map[string]*blob, len(m.Layers)+1)
blobs[NamePrefix] = m.Config
for _, layer := range m.Layers {
if layer.Name == "" && layer.MediaType != "" {
mediatype, _, err := mime.ParseMediaType(layer.MediaType)
if err != nil {
return nil, err
}
if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
layer.Name = NamePrefix + suffix
}
}
blobs[layer.Name] = layer
}
return &Root{
root: r,
name: n,
blobs: blobs,
flags: os.O_RDONLY,
}, nil
}
// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
r, err := root()
if err != nil {
return nil, err
}
return &Root{
root: r,
name: n,
blobs: make(map[string]*blob),
flags: os.O_RDWR,
}, nil
}
type blob struct {
Digest string `json:"digest"`
MediaType string `json:"mediaType"`
Name string `json:"name,omitempty"`
Size int64 `json:"size"`
// tempfile is the temporary file where the blob data is written.
tempfile *os.File
// hash is the hash.Hash used to compute the blob digest.
hash hash.Hash
}
func (b *blob) Write(p []byte) (int, error) {
return io.MultiWriter(b.tempfile, b.hash).Write(p)
}
func (b *blob) Filepath() string {
return strings.ReplaceAll(b.Digest, ":", "-")
}
type manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *blob `json:"config"`
Layers []*blob `json:"layers"`
}
// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
root *os.Root
name Name
blobs map[string]*blob
flags int
}
const MediaTypePrefix = "application/vnd.ollama"
// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are idenfitied by their media types:
//
// - name: NamePrefix + suffix
// - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
// - name: "./..image.model"
// - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."
// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
if b, ok := r.blobs[name]; ok {
r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
if err != nil {
return nil, err
}
return r, nil
}
return nil, fs.ErrNotExist
}
func (r Root) ReadFile(name string) ([]byte, error) {
f, err := r.Open(name)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
if r.flags&os.O_RDWR != 0 {
w, err := os.CreateTemp(r.root.Name(), "")
if err != nil {
return nil, err
}
r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
return r.blobs[name], nil
}
return nil, fs.ErrInvalid
}
// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
if r.flags&os.O_RDWR != 0 {
for _, b := range r.blobs {
if b.tempfile != nil {
fi, err := b.tempfile.Stat()
if err != nil {
return err
}
if err := b.tempfile.Close(); err != nil {
return err
}
b.Size = fi.Size()
b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
if b.Name == NamePrefix {
b.MediaType = "application/vnd.docker.container.image.v1+json"
} else {
b.MediaType = MediaTypePrefix + suffix
}
b.Name = ""
}
rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
if err != nil {
return err
}
if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
return err
}
}
}
p := filepath.Join("manifests", r.name.Filepath())
if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
return err
}
} else if err != nil {
return err
}
f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
if err != nil {
return err
}
defer f.Close()
if err := json.NewEncoder(f).Encode(manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: r.blobs[NamePrefix],
Layers: func() []*blob {
blobs := make([]*blob, 0, len(r.blobs))
for name, b := range r.blobs {
if name != NamePrefix {
blobs = append(blobs, b)
}
}
return blobs
}(),
}); err != nil {
return err
}
}
return r.root.Close()
}
// Name returns the name of the file.
func (r Root) Name() Name {
return r.name
}
// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
return maps.Keys(r.blobs)
}
// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
if _, err := filepath.Match(pattern, ""); err != nil {
return nil, err
}
return func(yield func(string) bool) {
for name := range r.blobs {
if matched, _ := filepath.Match(pattern, name); matched {
if !yield(name) {
return
}
}
}
}, nil
}
func (r Root) JoinPath(parts ...string) string {
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}
func (r Root) Real(name string) string {
if b, ok := r.blobs[name]; ok {
return b.Filepath()
}
return ""
}

View File

@@ -1,90 +0,0 @@
package model
import (
"io"
"strings"
"testing"
)
// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
for m, s := range models {
f, err := Create(m)
if err != nil {
t.Fatal(err)
}
for n, r := range s {
w, err := f.Create(n)
if err != nil {
t.Fatal(err)
}
if _, err := io.Copy(w, r); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
}
}
func TestOpen(t *testing.T) {
setup(t, map[Name]map[string]io.Reader{
ParseName("namespace/model"): {
"./.": strings.NewReader(`{"key":"value"}`),
},
ParseName("namespace/model:8b"): {
"./.": strings.NewReader(`{"foo":"bar"}`),
},
ParseName("another/model"): {
"./.": strings.NewReader(`{"another":"config"}`),
},
})
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
for _, name := range []string{"./."} {
r, err := f.Open(name)
if err != nil {
t.Fatal(err)
}
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
if err := r.Close(); err != nil {
t.Fatal(err)
}
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
t.Run("does not exist", func(t *testing.T) {
if _, err := Open(ParseName("namespace/unknown")); err == nil {
t.Error("expected error for unknown model")
}
})
t.Run("write", func(t *testing.T) {
f, err := Open(ParseName("namespace/model"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.Create("new-blob"); err == nil {
t.Error("expected error creating blob in read-only mode")
}
})
}

View File

@@ -1,33 +0,0 @@
package model
import (
"io/fs"
"iter"
"path/filepath"
)
func All() (iter.Seq[Name], error) {
r, err := root()
if err != nil {
return nil, err
}
manifests, err := r.OpenRoot("manifests")
if err != nil {
return nil, err
}
matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
if err != nil {
return nil, err
}
return func(yield func(Name) bool) {
for _, match := range matches {
name := ParseNameFromFilepath(filepath.ToSlash(match))
if !yield(name) {
return
}
}
}, nil
}

View File

@@ -227,17 +227,6 @@ func (n Name) String() string {
return b.String()
}
// Set implements [flag.Value]. It parses the provided input as a name string
// and sets the receiver to the parsed value. If the parsed name is not valid,
// ErrUnqualifiedName is returned.
func (n *Name) Set(s string) error {
*n = ParseName(s)
if !n.IsValid() {
return ErrUnqualifiedName
}
return nil
}
// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
var sb strings.Builder

View File

@@ -1,94 +0,0 @@
package mlxrunner
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
type CacheEntry struct {
Caches []cache.Cache
Count int
Entries map[int32]*CacheEntry
}
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
current := &CacheEntry{Entries: s.CacheEntries}
index, cacheIndex := 0, -1
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
break
}
current = current.Entries[token]
if len(current.Caches) > 0 {
cacheIndex = index
}
index += 1
}
if cacheIndex == len(tokens)-1 {
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
return current.Caches, []int32{}
} else if cacheIndex > 1 {
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
return current.Caches, tokens[cacheIndex+1:]
} else if index > 0 && cacheIndex < 0 {
type stackItem struct {
entry *CacheEntry
tokens []int32
}
var best, item stackItem
stack := []stackItem{{entry: current, tokens: []int32{}}}
for len(stack) > 0 {
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
if len(item.entry.Caches) > 0 {
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
best = item
}
} else {
for token, entry := range item.entry.Entries {
stack = append(stack, stackItem{
entry: entry,
tokens: append(item.tokens, token),
})
}
}
}
prefix := min(len(tokens)-1, index)
caches := make([]cache.Cache, len(best.entry.Caches))
trim := len(best.tokens)+1
for i := range caches {
caches[i] = best.entry.Caches[i].Clone()
caches[i].Trim(trim)
}
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
return caches, tokens[prefix:]
}
slog.Info("Cache miss", "left", len(tokens))
return nil, tokens
}
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
current := &CacheEntry{Entries: s.CacheEntries}
for _, token := range tokens {
if _, ok := current.Entries[token]; !ok {
current.Entries[token] = &CacheEntry{
Entries: make(map[int32]*CacheEntry),
}
}
current = current.Entries[token]
}
if len(current.Caches) > 0 {
current.Count += 1
} else {
current.Caches = caches
}
}

View File

@@ -1,196 +0,0 @@
package cache
import (
"log/slog"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
State() (keys, values *mlx.Array)
Trim(int) int
Clone() Cache
Offset() int
Len() int
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if needed
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys != nil {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
}
c.offset += L
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset == c.keys.Dim(2) {
return c.keys, c.values
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
return n
}
func (c *KVCache) Clone() Cache {
return &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
maxSize int
idx int
*KVCache
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys, values
} else {
if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
}
// Trim to max_size to maintain sliding window
if trim := c.idx - c.maxSize + 1; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, keys))
c.values.Set(c.values.Concatenate(2, values))
c.idx = c.keys.Dim(2)
}
c.offset += keys.Dim(2)
c.idx = c.keys.Dim(2)
return c.keys, c.values
}
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if not yet at max
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys != nil {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
}
c.idx = prev
}
// Trim to max_size to maintain sliding window
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
c.idx = c.maxSize
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.offset += L
c.idx += L
validLen := min(c.offset, c.maxSize)
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
if c.offset < c.keys.Dim(2) {
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
return c.keys, c.values
}
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
c.idx -= n
return n
}
func (c *RotatingKVCache) Clone() Cache {
return &RotatingKVCache{
maxSize: c.maxSize,
idx: c.idx,
KVCache: c.KVCache.Clone().(*KVCache),
}
}
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

View File

@@ -1,174 +0,0 @@
package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"math"
"net"
"net/http"
"net/url"
"os/exec"
"strconv"
"strings"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
type Client struct {
Port int
*exec.Cmd
}
func (c *Client) JoinPath(path string) string {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
}).JoinPath(path).String()
}
func (c *Client) CheckError(w *http.Response) error {
if w.StatusCode >= 400 {
return errors.New(w.Status)
}
return nil
}
// Close implements llm.LlamaServer.
func (c *Client) Close() error {
return c.Cmd.Process.Kill()
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(req); err != nil {
return err
}
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
if err != nil {
return err
}
defer w.Body.Close()
if err := c.CheckError(w); err != nil {
return err
}
scanner := bufio.NewScanner(w.Body)
for scanner.Scan() {
bts := scanner.Bytes()
var resp llm.CompletionResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
fn(resp)
}
return nil
}
func (c *Client) ContextLength() int {
return math.MaxInt
}
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
panic("unimplemented")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
panic("unimplemented")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
panic("unimplemented")
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.Port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
panic("unimplemented")
}
// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
if err != nil {
return nil, err
}
defer w.Body.Close()
return []ml.DeviceID{}, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
panic("unimplemented")
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
panic("unimplemented")
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
w, err := http.Get(c.JoinPath("/v1/status"))
if err != nil {
return err
}
defer w.Body.Close()
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
if err != nil {
return nil, err
}
defer w.Body.Close()
var tokens []int
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
return nil, err
}
return tokens, nil
}
// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
panic("unimplemented")
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
panic("unimplemented")
}
// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
panic("unimplemented")
}
// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
panic("unimplemented")
}
var _ llm.LlamaServer = (*Client)(nil)

View File

@@ -1,3 +0,0 @@
_deps
build
dist

View File

@@ -1,26 +0,0 @@
cmake_minimum_required(VERSION 3.5)
project(mlx)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

View File

@@ -1,45 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"math"
"sync"
)
var geluApprox = sync.OnceValue(func() *Closure {
return Compile(func(inputs []*Array) []*Array {
input := inputs[0]
return []*Array{
input.Multiply(
FromValue[float32](0.5),
).Multiply(
input.Add(
input.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(input.DType()),
}
}, true)
})
var silu = sync.OnceValue(func() *Closure {
return Compile(func(inputs []*Array) []*Array {
input := inputs[0]
return []*Array{
input.Multiply(
input.Sigmoid(),
).AsType(input.DType()),
}
}, true)
})
func GELUApprox(t *Array) *Array {
return geluApprox().Call([]*Array{t})[0]
}
func SILU(t *Array) *Array {
return silu().Call([]*Array{t})[0]
}

View File

@@ -1,271 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"log/slog"
"reflect"
"strings"
"time"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type tensorDesc struct {
name string
inputs []*Array
numRefs int
}
func (d tensorDesc) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", d.name),
slog.Int("inputs", len(d.inputs)),
slog.Int("num_refs", d.numRefs),
)
}
type Array struct {
ctx C.mlx_array
desc tensorDesc
}
// constructor utilities
func New(name string, inputs ...*Array) *Array {
t := &Array{
desc: tensorDesc{
name: name,
inputs: inputs,
},
}
for _, input := range inputs {
input.desc.numRefs++
}
logutil.Trace("New", "t", t)
return t
}
type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
case int:
tt.ctx = C.mlx_array_new_int(C.int(v))
case float32:
tt.ctx = C.mlx_array_new_float32(C.float(v))
case float64:
tt.ctx = C.mlx_array_new_float64(C.double(v))
case complex64:
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
default:
panic("unsupported type")
}
return tt
}
type arrayTypes interface {
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~int8 | ~int16 | ~int32 | ~int64 |
~float32 | ~float64 |
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
cShape := make([]C.int, len(shape))
for i := range shape {
cShape[i] = C.int(shape[i])
}
var dtype DType
switch reflect.TypeOf(s).Elem().Kind() {
case reflect.Bool:
dtype = DTypeBool
case reflect.Uint8:
dtype = DTypeUint8
case reflect.Uint16:
dtype = DTypeUint16
case reflect.Uint32:
dtype = DTypeUint32
case reflect.Uint64:
dtype = DTypeUint64
case reflect.Int8:
dtype = DTypeInt8
case reflect.Int16:
dtype = DTypeInt16
case reflect.Int32:
dtype = DTypeInt32
case reflect.Int64:
dtype = DTypeInt64
case reflect.Float32:
dtype = DTypeFloat32
case reflect.Float64:
dtype = DTypeFloat64
case reflect.Complex64:
dtype = DTypeComplex64
default:
panic("unsupported type")
}
bts := make([]byte, binary.Size(s))
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
panic(err)
}
tt := New("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
func (t *Array) Set(other *Array) {
other.desc.numRefs++
t.desc.inputs = []*Array{other}
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.desc.name, t.desc.inputs...)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// misc. utilities
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{slog.Any("", t.desc)}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
slog.Any("shape", t.Dims()),
slog.Int("num_bytes", t.NumBytes()),
)
}
return slog.GroupValue(attrs...)
}
// shape utilities
func (t Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
}
return dims
}
func (t Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
func (t Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats
}
func (t Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
return nil
}
func Free(s ...*Array) (n int) {
now := time.Now()
defer func() {
if n > 0 {
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
}
}()
free := make([]*Array, 0, 8192)
fn := func(t *Array) {
if t.Valid() {
free = append(free, t.desc.inputs...)
t.desc.numRefs--
if t.desc.numRefs <= 0 {
logutil.Trace("Free", "t", t)
n += t.NumBytes()
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
}
for _, t := range s {
fn(t)
}
for len(free) > 0 {
tail := free[len(free)-1]
free = free[:len(free)-1]
fn(tail)
}
return n
}

View File

@@ -1,43 +0,0 @@
package mlx
import "testing"
func TestFromValue(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}
func TestFromValues(t *testing.T) {
for got, want := range map[*Tensor]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
t.Run(want.String(), func(t *testing.T) {
if got.DType() != want {
t.Errorf("want %v, got %v", want, got)
}
})
}
}

View File

@@ -1,76 +0,0 @@
package mlx
// #include "generated.h"
// int goClosureFunc(mlx_vector_array*, mlx_vector_array, void*);
// void goClosureDestructor(void*);
import "C"
import (
"runtime/cgo"
"unsafe"
)
type Closure struct {
ctx C.mlx_closure
}
func (c Closure) Call(inputs []*Array) []*Array {
inputsVector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inputsVector)
for _, input := range inputs {
C.mlx_vector_array_append_value(inputsVector, input.ctx)
}
outputsVector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outputsVector)
C.mlx_closure_apply(&outputsVector, c.ctx, inputsVector)
outputs := make([]*Array, int(C.mlx_vector_array_size(outputsVector)))
for i := range outputs {
t := New("", inputs...)
C.mlx_vector_array_get(&t.ctx, outputsVector, C.size_t(i))
outputs[i] = t
}
return outputs
}
func Compile(fn func([]*Array) []*Array, shapeless bool) *Closure {
closure := C.mlx_closure_new_func_payload(
(*[0]byte)(C.goClosureFunc),
unsafe.Pointer(cgo.NewHandle(fn)),
(*[0]byte)(C.goClosureDestructor),
)
compiled := C.mlx_closure_new()
C.mlx_compile(&compiled, closure, C.bool(shapeless))
return &Closure{ctx: compiled}
}
//export goClosureFunc
func goClosureFunc(outputsVector *C.mlx_vector_array, inputsVector C.mlx_vector_array, payload unsafe.Pointer) C.int {
handle := cgo.Handle(payload)
fn := handle.Value().(func([]*Array) []*Array)
inputs := make([]*Array, int(C.mlx_vector_array_size(inputsVector)))
for i := range inputs {
t := New("")
C.mlx_vector_array_get(&t.ctx, inputsVector, C.size_t(i))
inputs[i] = t
}
var outputs []C.mlx_array
for _, output := range fn(inputs) {
outputs = append(outputs, output.ctx)
}
C.mlx_vector_array_set_data(outputsVector, unsafe.SliceData(outputs), C.size_t(len(outputs)))
return 0
}
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
cgo.Handle(payload).Delete()
}

View File

@@ -1,94 +0,0 @@
package mlx
// #include "generated.h"
import "C"
type DType int
func (t DType) String() string {
switch t {
case DTypeBool:
return "BOOL"
case DTypeUint8:
return "U8"
case DTypeUint16:
return "U16"
case DTypeUint32:
return "U32"
case DTypeUint64:
return "U64"
case DTypeInt8:
return "I8"
case DTypeInt16:
return "I16"
case DTypeInt32:
return "I32"
case DTypeInt64:
return "I64"
case DTypeFloat16:
return "F16"
case DTypeFloat32:
return "F32"
case DTypeFloat64:
return "F64"
case DTypeBFloat16:
return "BF16"
case DTypeComplex64:
return "C64"
default:
return "Unknown"
}
}
func (t *DType) UnmarshalJSON(b []byte) error {
switch string(b) {
case `"BOOL"`:
*t = DTypeBool
case `"U8"`:
*t = DTypeUint8
case `"U16"`:
*t = DTypeUint16
case `"U32"`:
*t = DTypeUint32
case `"U64"`:
*t = DTypeUint64
case `"I8"`:
*t = DTypeInt8
case `"I16"`:
*t = DTypeInt16
case `"I32"`:
*t = DTypeInt32
case `"I64"`:
*t = DTypeInt64
case `"F16"`:
*t = DTypeFloat16
case `"F64"`:
*t = DTypeFloat64
case `"F32"`:
*t = DTypeFloat32
case `"BF16"`:
*t = DTypeBFloat16
case `"C64"`:
*t = DTypeComplex64
default:
return nil
}
return nil
}
const (
DTypeBool DType = C.MLX_BOOL
DTypeUint8 DType = C.MLX_UINT8
DTypeUint16 DType = C.MLX_UINT16
DTypeUint32 DType = C.MLX_UINT32
DTypeUint64 DType = C.MLX_UINT64
DTypeInt8 DType = C.MLX_INT8
DTypeInt16 DType = C.MLX_INT16
DTypeInt32 DType = C.MLX_INT32
DTypeInt64 DType = C.MLX_INT64
DTypeFloat16 DType = C.MLX_FLOAT16
DTypeFloat32 DType = C.MLX_FLOAT32
DTypeFloat64 DType = C.MLX_FLOAT64
DTypeBFloat16 DType = C.MLX_BFLOAT16
DTypeComplex64 DType = C.MLX_COMPLEX64
)

View File

@@ -1,34 +0,0 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
CHECK(handle->ctx != NULL);
return 0;
}
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
return mlx_dynamic_open(handle, path);
}
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
if (handle->ctx) {
DLCLOSE(handle->ctx);
handle->ctx = NULL;
}
}

View File

@@ -1,63 +0,0 @@
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"unsafe"
)
func init() {
switch runtime.GOOS {
case "darwin":
case "windows":
default:
return
}
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
if !ok {
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
return
}
for _, path := range filepath.SplitList(paths) {
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
if err != nil {
panic(err)
}
for _, match := range matches {
path := filepath.Join(paths, match)
slog.Info("Loading MLX dynamic library", "path", path)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
slog.Error("Failed to load MLX dynamic library", "path", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
C.mlx_dynamic_unload(&handle)
continue
}
slog.Info("Loaded MLX dynamic library", "path", path)
return
}
}
panic("Failed to load any MLX dynamic library")
}

View File

@@ -1,27 +0,0 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

View File

@@ -1,72 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
if mask == nil {
mask = New("")
}
sinks := New("")
mode := "causal"
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM", x)
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM", x)
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RoPE struct {
Dims int
Traditional bool
Base float32 `json:"rope_theta"`
Scale float32
}
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", t, freqs)
C.mlx_fast_rope(
&out.ctx,
t.ctx,
C.int(r.Dims),
C._Bool(r.Traditional),
C.mlx_optional_float{
value: C.float(r.Base),
has_value: C._Bool(func() bool { return r.Base != 0 }()),
},
C.float(r.Scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +0,0 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
CHECK_LOAD(handle, {{ .Name }});
{{- end }}
return 0;
}
{{- range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}

View File

@@ -1,20 +0,0 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }};
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -1,135 +0,0 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
type Function struct {
Type,
Name,
Parameters,
Args string
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
var fn Function
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
if params := node.ChildByFieldName("parameters"); params != nil {
fn.Parameters = params.Utf8Text(source)
fn.Args = ParseParameters(params, tc, source)
}
var types []string
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
if node.Parent().Kind() == "pointer_declarator" {
types = append(types, "*")
}
node = node.Parent()
}
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
types = append(types, sibling.Utf8Text(source))
}
slices.Reverse(types)
fn.Type = strings.Join(types, " ")
return fn
}
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
var s []string
for _, child := range node.Children(tc) {
if child.IsNamed() {
child := child.ChildByFieldName("declarator")
for child != nil && child.Kind() != "identifier" {
if child.Kind() == "parenthesized_declarator" {
child = child.Child(1)
} else {
child = child.ChildByFieldName("declarator")
}
}
if child != nil {
s = append(s, child.Utf8Text(source))
}
}
}
return strings.Join(s, ", ")
}
func main() {
var output string
flag.StringVar(&output, "output", ".", "Output directory for generated files")
flag.Parse()
parser := tree_sitter.NewParser()
defer parser.Close()
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
parser.SetLanguage(language)
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
defer query.Close()
qc := tree_sitter.NewQueryCursor()
defer qc.Close()
var funs []Function
for _, arg := range flag.Args() {
bts, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
continue
}
tree := parser.Parse(bts, nil)
defer tree.Close()
tc := tree.Walk()
defer tc.Close()
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
funs = append(funs, ParseFunction(&capture.Node, tc, bts))
}
}
}
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
return
}
for _, tmpl := range tmpl.Templates() {
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
fmt.Println("Generating", name)
f, err := os.Create(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
continue
}
defer f.Close()
if err := tmpl.Execute(f, map[string]any{
"Functions": funs,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
}
}
}

View File

@@ -1,166 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"encoding/json"
"errors"
"io"
"iter"
"log/slog"
"maps"
"slices"
"unsafe"
"github.com/ollama/ollama/types/model"
)
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
string2array := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(string2array)
string2string := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(string2string)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cpu := C.mlx_default_cpu_stream_new()
defer C.mlx_stream_free(cpu)
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
it := C.mlx_map_string_to_array_iterator_new(string2array)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
break
}
}
}
}
func Parse(root *model.Root, path string) (map[string]Quantization, error) {
f, err := root.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var n uint64
if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
return nil, err
}
bts := make([]byte, n)
if _, err := io.ReadFull(f, bts); err != nil {
return nil, err
}
var m struct {
Metadata struct {
Quantization map[string]Quantization `json:"quantization"`
} `json:"__metadata__"`
}
if err := json.Unmarshal(bts, &m); err != nil {
return nil, err
}
return m.Metadata.Quantization, nil
}
func LoadWeights(root *model.Root, match string, states map[string]*Array) error {
slog.Debug("Loading weights from", "file", match)
for name, weight := range Load(root.JoinPath("blobs", root.Real(match))) {
if state, ok := states[name]; ok {
*state = *weight
}
}
return nil
}
func LoadQuantizations(root *model.Root, match string, quantizations map[string]*Quantization) error {
slog.Debug("Loading quantizations from", "file", match)
metadata, err := Parse(root, match)
if err != nil {
return err
}
for name := range metadata {
if q, ok := quantizations[name+".weight"]; ok {
q.GroupSize = metadata[name].GroupSize
q.Bits = metadata[name].Bits
q.Mode = metadata[name].Mode
}
}
return nil
}
type AfterLoadFunc func(*model.Root) ([]*Array, error)
func LoadAll(root *model.Root, states map[string]*Array, quantizations map[string]*Quantization, afterLoadFuncs []AfterLoadFunc) error {
matches, err := root.Glob("model*.safetensors")
if err != nil {
return err
}
for match := range matches {
if err := errors.Join(
LoadWeights(root, match, states),
LoadQuantizations(root, match, quantizations),
); err != nil {
return err
}
}
for _, afterLoadFunc := range afterLoadFuncs {
weights, err := afterLoadFunc(root)
if err != nil {
return err
}
for _, weight := range weights {
weight.desc.numRefs = 1000
Eval(weight)
var freeAll func(...*Array)
freeAll = func(inputs ...*Array) {
for _, input := range inputs {
input.desc.numRefs = 0
freeAll(input.desc.inputs...)
}
Free(inputs...)
}
freeAll(weight.desc.inputs...)
}
}
Eval(slices.Collect(maps.Values(states))...)
ClearCache()
slog.Info("Loaded weights", "count", len(states), "memory", Memory{})
return nil
}
func UnloadAll(states map[string]*Array) {
weights := slices.Collect(maps.Values(states))
for _, weight := range weights {
weight.desc.numRefs = 0
}
numBytes := Free(weights...)
slog.Info("Unloaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
}

View File

@@ -1,85 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
func (b Byte) String() string {
return strconv.FormatInt(int64(b), 10) + " B"
}
func (b KibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}
func (b MebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}
func (b GibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}
func (b TebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}
func PrettyBytes(n int) fmt.Stringer {
switch {
case n < 1<<10:
return Byte(n)
case n < 1<<(2*10):
return KibiByte(n)
case n < 1<<(3*10):
return MebiByte(n)
case n < 1<<(4*10):
return GibiByte(n)
default:
return TebiByte(n)
}
}
func ActiveMemory() int {
var active C.size_t
C.mlx_get_active_memory(&active)
return int(active)
}
func CacheMemory() int {
var cache C.size_t
C.mlx_get_cache_memory(&cache)
return int(cache)
}
func PeakMemory() int {
var peak C.size_t
C.mlx_get_peak_memory(&peak)
return int(peak)
}
type Memory struct{}
func (Memory) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("active", PrettyBytes(ActiveMemory())),
slog.Any("cache", PrettyBytes(CacheMemory())),
slog.Any("peak", PrettyBytes(PeakMemory())),
)
}
type (
Byte int
KibiByte int
MebiByte int
GibiByte int
TebiByte int
)
func ClearCache() {
C.mlx_clear_cache()
}

View File

@@ -1,38 +0,0 @@
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"
func doEval(outputs []*Array, async bool) {
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
if async {
C.mlx_async_eval(vector)
} else {
C.mlx_eval(vector)
}
}
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Array) {
doEval(outputs, false)
}

View File

@@ -1,102 +0,0 @@
package mlx
import "cmp"
type Quantization struct {
Scales Array `weight:"scales"`
Biases Array `weight:"biases"`
GroupSize int `json:"group_size"`
Bits int `json:"bits"`
Mode string `json:"mode"`
}
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
Quantization
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
if m.Scales.Valid() {
x = x.QuantizedMatmul(
&m.Weight,
&m.Scales,
&m.Biases,
true,
m.GroupSize,
m.Bits,
cmp.Or(m.Mode, "affine"),
)
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
}
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
if m.Scales.Valid() {
x = x.GatherQMM(
&m.Weight,
&m.Scales,
&m.Biases,
lhs,
rhs,
sorted,
m.GroupSize,
m.Bits,
cmp.Or(m.Mode, "affine"),
sorted,
)
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
} else {
w := m.Weight.Transpose(0, 2, 1)
x = x.GatherMM(w, lhs, rhs, sorted)
}
if m.Bias.Valid() {
x = m.Bias.Add(x)
}
return x
}
type Embedding struct {
Weight Array `weight:"weight"`
Quantization
}
func (e *Embedding) Forward(indices *Array) *Array {
if e.Scales.Valid() {
w := e.Weight.TakeAxis(indices, 0)
return w.Dequantize(
e.Scales.TakeAxis(indices, 0),
e.Biases.TakeAxis(indices, 0),
e.GroupSize,
e.Bits,
cmp.Or(e.Mode, "affine"),
)
}
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() Linear {
return Linear{
Weight: e.Weight,
Quantization: e.Quantization,
}
}

View File

@@ -1,341 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func (t *Array) Abs() *Array {
out := New("ABS", t)
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD", t, other)
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM", t, a, b)
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX", t)
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION", t)
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS", t)
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE", t)
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED", t)
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
unsafe.SliceData(cStrides), C.size_t(len(strides)),
C.size_t(offset),
DefaultStream().ctx,
)
return out
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
s := append([]*Array{t}, others...)
for _, other := range s {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE", s...)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE", t, other)
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Dequantize(scales, biases *Array, groupSize, bits int, mode string) *Array {
out := New("DEQUANTIZE", t, scales, biases)
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
C.mlx_dequantize(
&out.ctx,
t.ctx,
scales.ctx,
biases.ctx,
C.mlx_optional_int{
value: C.int(groupSize),
has_value: C.bool(groupSize > 0),
},
C.mlx_optional_int{
value: C.int(bits),
has_value: C.bool(bits > 0),
},
cMode,
C.mlx_optional_dtype{
has_value: false,
},
DefaultStream().ctx,
)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS", t)
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN", t)
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE", t, other)
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if lhs == nil {
lhs = New("")
}
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM", t, other, lhs, rhs)
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) GatherQMM(weight, scales, biases, lhs, rhs *Array, transpose bool, groupSize, bits int, mode string, sorted bool) *Array {
if lhs == nil {
lhs = New("")
}
if rhs == nil {
rhs = New("")
}
out := New("GATHER_QMM", t, weight, scales, biases, lhs, rhs)
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
C.mlx_gather_qmm(
&out.ctx,
t.ctx,
weight.ctx,
scales.ctx,
biases.ctx,
lhs.ctx,
rhs.ctx,
C.bool(transpose),
C.mlx_optional_int{
value: C.int(groupSize),
has_value: C.bool(groupSize > 0),
},
C.mlx_optional_int{
value: C.int(bits),
has_value: C.bool(bits > 0),
},
cMode,
C.bool(sorted),
DefaultStream().ctx,
)
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP", t)
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL", t, other)
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY", t, other)
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE", t)
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER", t, exponent)
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", t, indices, values)
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) QuantizedMatmul(weight, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
out := New("QUANTIZED_MATMUL", t, weight, scales, biases)
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
C.mlx_quantized_matmul(
&out.ctx,
t.ctx,
weight.ctx,
scales.ctx,
biases.ctx,
C.bool(transpose),
C.mlx_optional_int{
value: C.int(groupSize),
has_value: C.bool(groupSize > 0),
},
C.mlx_optional_int{
value: C.int(bits),
has_value: C.bool(bits > 0),
},
cMode,
DefaultStream().ctx,
)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE", t)
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID", t)
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT", t)
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE", t)
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS", append(others, t)...)
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT", t, other)
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS", t)
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS", t, indices)
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", t, indices)
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH", t)
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE", t)
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])
}
t := New("ZEROS")
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
return t
}

View File

@@ -1,11 +0,0 @@
package mlx
// #include "generated.h"
import "C"
func (t *Array) Categorical(axis int) *Array {
key := New("")
out := New("", t, key)
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}

View File

@@ -1,84 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"cmp"
"unsafe"
)
type slice struct {
args []int
}
func Slice(args ...int) slice {
return slice{args: args}
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
}
args := [3][]C.int{
make([]C.int, len(slices)),
make([]C.int, len(slices)),
make([]C.int, len(slices)),
}
for i, s := range slices {
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dims[i])
args[2][i] = C.int(1)
case 1:
// slice[i]
args[0][i] = C.int(s.args[0])
args[1][i] = C.int(s.args[0] + 1)
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")
}
}
return args[0], args[1], args[2]
}
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE", t)
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE", t, other)
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}

View File

@@ -1,43 +0,0 @@
package mlx
// #include "generated.h"
import "C"
import (
"log/slog"
"sync"
)
type Device struct {
ctx C.mlx_device
}
func (d Device) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_device_tostring(&str, d.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultDevice = sync.OnceValue(func() Device {
d := C.mlx_device_new()
C.mlx_get_default_device(&d)
return Device{d}
})
type Stream struct {
ctx C.mlx_stream
}
func (s Stream) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_stream_tostring(&str, s.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var DefaultStream = sync.OnceValue(func() Stream {
s := C.mlx_stream_new()
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
return Stream{s}
})

View File

@@ -1,8 +0,0 @@
package base
import "github.com/ollama/ollama/x/mlxrunner/cache"
// Cacher is implemented by models that support custom caching mechanisms.
type Cacher interface {
Cache() []cache.Cache
}

View File

@@ -1,116 +0,0 @@
package base
import (
"encoding/json"
"errors"
"log/slog"
"reflect"
"strconv"
"strings"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Model interface {
// Forward performs a forward pass through the model.
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
// NumLayers returns the number of layers in the model.
// This is used to initialize caches.
// TODO: consider moving cache initialization into the model itself.
NumLayers() int
}
type TextGeneration interface {
Model
Unembed(*mlx.Array) *mlx.Array
}
func Walk(m Model) (map[string]*mlx.Array, map[string]*mlx.Quantization, []mlx.AfterLoadFunc) {
weights := make(map[string]*mlx.Array)
quantizations := make(map[string]*mlx.Quantization)
var afterLoadFuncs []mlx.AfterLoadFunc
var fn func(v reflect.Value, tags []string)
fn = func(v reflect.Value, tags []string) {
t := v.Type()
if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
var afterLoadFunc mlx.AfterLoadFunc
reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
}
if t == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
name := strings.Join(tags, ".")
weights[name] = v.Addr().Interface().(*mlx.Array)
return
} else if t == reflect.TypeOf((*mlx.Quantization)(nil)).Elem() {
quantizations[strings.Join(tags, ".")] = v.Addr().Interface().(*mlx.Quantization)
}
for _, field := range reflect.VisibleFields(t) {
if field.IsExported() {
tt, vv := field.Type, v.FieldByIndex(field.Index)
// create local copy so tags are not modified between fields
tags := tags
if tag := field.Tag.Get("weight"); tag != "" {
// TODO: use model.Tag
tags = append(tags, tag)
}
switch tt.Kind() {
case reflect.Interface:
vv = vv.Elem()
fallthrough
case reflect.Pointer:
vv = vv.Elem()
fallthrough
case reflect.Struct:
fn(vv, tags)
case reflect.Slice, reflect.Array:
for i := range vv.Len() {
fn(vv.Index(i), append(tags, strconv.Itoa(i)))
}
}
}
}
}
fn(reflect.ValueOf(m).Elem(), []string{})
return weights, quantizations, afterLoadFuncs
}
var m = make(map[string]func(*model.Root) (Model, error))
func Register(name string, f func(*model.Root) (Model, error)) {
if _, exists := m[name]; exists {
panic("model already registered: " + name)
}
m[name] = f
}
func New(root *model.Root) (Model, error) {
c, err := root.Open("config.json")
if err != nil {
return nil, err
}
defer c.Close()
var config struct {
Architectures []string `json:"architectures"`
}
if err := json.NewDecoder(c).Decode(&config); err != nil {
return nil, err
}
slog.Info("Model architecture", "arch", config.Architectures[0])
if f, exists := m[config.Architectures[0]]; exists {
return f(root)
}
return nil, errors.New("unknown architecture")
}

View File

@@ -1,84 +0,0 @@
package gemma
import (
"cmp"
"encoding/json"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
type Model struct {
Text TextModel `weight:"language_model"`
}
func (m *Model) NumLayers() int {
return len(m.Text.Layers)
}
func (m Model) Cache() []cache.Cache {
caches := make([]cache.Cache, m.NumLayers())
for i := range caches {
if (i+1)%m.Text.Options.SlidingWindowPattern == 0 {
caches[i] = cache.NewKVCache()
} else {
caches[i] = cache.NewRotatingKVCache(m.Text.Options.SlidingWindow)
}
}
return caches
}
func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
return m.Text.Forward(inputs, cache)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.Text.EmbedTokens.AsLinear().Forward(x)
}
func init() {
base.Register("Gemma3ForConditionalGeneration", func(root *model.Root) (base.Model, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts struct {
Text TextOptions `json:"text_config"`
}
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
opts.Text.NumAttentionHeads = cmp.Or(opts.Text.NumAttentionHeads, 8)
opts.Text.NumKeyValueHeads = cmp.Or(opts.Text.NumKeyValueHeads, 4)
opts.Text.HeadDim = cmp.Or(opts.Text.HeadDim, 256)
opts.Text.RMSNormEps = cmp.Or(opts.Text.RMSNormEps, 1e-6)
opts.Text.SlidingWindowPattern = cmp.Or(opts.Text.SlidingWindowPattern, 6)
// TODO: implement json.Unmarshaler
opts.Text.RoPE = map[bool]mlx.RoPE{
true: {Dims: opts.Text.HeadDim, Traditional: false, Base: 1_000_000, Scale: 1. / 8.},
false: {Dims: opts.Text.HeadDim, Traditional: false, Base: 10_000, Scale: 1},
}
return &Model{
Text: TextModel{
Layers: make([]TextDecoderLayer, opts.Text.NumHiddenLayers),
Options: opts.Text,
},
}, nil
})
}
type RMSNorm struct {
mlx.RMSNorm
}
func (m *RMSNorm) AfterLoad(*model.Root) ([]*mlx.Array, error) {
m.Weight.Set(m.Weight.Add(mlx.FromValue(1)))
return []*mlx.Array{}, nil
}

View File

@@ -1,118 +0,0 @@
package gemma
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type TextOptions struct {
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
HeadDim int `json:"head_dim"`
RMSNormEps float32 `json:"rms_norm_eps"`
SlidingWindow int `json:"sliding_window"`
SlidingWindowPattern int `json:"sliding_window_pattern"`
RoPE map[bool]mlx.RoPE
}
type TextModel struct {
EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
Layers []TextDecoderLayer `weight:"model.layers"`
Norm RMSNorm `weight:"model.norm"`
Options TextOptions
}
func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
hiddenSize := mlx.FromValue(m.Options.HiddenSize).AsType(hiddenStates.DType())
hiddenStates = hiddenStates.Multiply(hiddenSize.Sqrt())
for i, layer := range m.Layers {
hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options.RoPE[(i+1)%m.Options.SlidingWindowPattern == 0], m.Options)
}
hiddenStates = m.Norm.Forward(hiddenStates, m.Options.RMSNormEps)
return hiddenStates
}
type TextDecoderLayer struct {
InputNorm RMSNorm `weight:"input_layernorm"`
Attention TextAttention `weight:"self_attn"`
PostAttnNorm RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm RMSNorm `weight:"pre_feedforward_layernorm"`
MLP TextMLP `weight:"mlp"`
PostFFNorm RMSNorm `weight:"post_feedforward_layernorm"`
}
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
residual := hiddenStates
hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
hiddenStates = m.PostAttnNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = hiddenStates.Add(residual)
residual = hiddenStates
hiddenStates = m.PreFFNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.MLP.Forward(hiddenStates, opts)
hiddenStates = m.PostFFNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = hiddenStates.Add(residual)
return hiddenStates
}
type TextAttention struct {
QProj mlx.Linear `weight:"q_proj"`
QNorm RMSNorm `weight:"q_norm"`
KProj mlx.Linear `weight:"k_proj"`
KNorm RMSNorm `weight:"k_norm"`
VProj mlx.Linear `weight:"v_proj"`
OProj mlx.Linear `weight:"o_proj"`
}
func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
query := m.QProj.Forward(hiddenStates)
key := m.KProj.Forward(hiddenStates)
value := m.VProj.Forward(hiddenStates)
query = query.AsStrided(
[]int{B, opts.NumAttentionHeads, L, opts.HeadDim},
[]int{L * opts.NumAttentionHeads * opts.HeadDim, opts.HeadDim, opts.NumAttentionHeads * opts.HeadDim, 1},
0)
key = key.AsStrided(
[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
0)
value = value.AsStrided(
[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
0)
query = m.QNorm.Forward(query, opts.RMSNormEps)
key = m.KNorm.Forward(key, opts.RMSNormEps)
query = rope.Forward(query, cache.Offset())
key = rope.Forward(key, cache.Offset())
key, value = cache.Update(key, value)
attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(opts.HeadDim))))
attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
return m.OProj.Forward(attention)
}
type TextMLP struct {
GateProj mlx.Linear `weight:"gate_proj"`
UpProj mlx.Linear `weight:"up_proj"`
DownProj mlx.Linear `weight:"down_proj"`
}
func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}

View File

@@ -1,334 +0,0 @@
package glm
import (
"encoding/json"
"math"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
type Options struct {
HiddenSize int `json:"hidden_size"`
NumHiddenLayers int `json:"num_hidden_layers"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumKeyValueHeads int `json:"num_key_value_heads"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QLoraRank int `json:"q_lora_rank"`
KVLoraRank int `json:"kv_lora_rank"`
QKRopeHeadDim int `json:"qk_rope_head_dim"`
QKNopeHeadDim int `json:"qk_nope_head_dim"`
NumRoutedExperts int `json:"n_routed_experts"`
NumSharedExperts int `json:"n_shared_experts"`
NumExpertsPerTok int `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int `json:"first_k_dense_replace"`
mlx.RoPE
}
type Model struct {
EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
Layers []Layer `weight:"model.layers"`
Norm mlx.RMSNorm `weight:"model.norm"`
LMHead mlx.Linear `weight:"lm_head"`
Options
}
func (m Model) NumLayers() int {
return len(m.Layers)
}
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := inputs.Dim(0), inputs.Dim(1)
h := m.EmbedTokens.Forward(inputs)
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.Options)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return h
}
func (m Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
type Layer struct {
InputLayernorm mlx.RMSNorm `weight:"input_layernorm"`
Attention Attention `weight:"self_attn"`
PostAttentionLayernorm mlx.RMSNorm `weight:"post_attention_layernorm"`
MLP MLP `weight:"mlp"`
}
func (m Layer) Forward(h *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
r := h
h = m.InputLayernorm.Forward(h, opts.RMSNormEps)
h = m.Attention.Forward(h, cache, B, L, opts)
h = h.Add(r)
r = h
h = m.PostAttentionLayernorm.Forward(h, opts.RMSNormEps)
h = m.MLP.Forward(h, B, L, opts)
h = h.Add(r)
return h
}
type MultiLinear struct {
Weight mlx.Array `weight:"weight"`
}
func (m MultiLinear) Forward(x *mlx.Array) *mlx.Array {
return x.Matmul(m.Weight.Transpose(0, 2, 1))
}
type Attention struct {
QAProj mlx.Linear `weight:"q_a_proj"`
QALayernorm mlx.RMSNorm `weight:"q_a_layernorm"`
QBProj mlx.Linear `weight:"q_b_proj"`
KVAProjWithMQA mlx.Linear `weight:"kv_a_proj_with_mqa"`
KVALayernorm mlx.RMSNorm `weight:"kv_a_layernorm"`
KVBProj mlx.Linear `weight:"kv_b_proj"`
embedQ MultiLinear
unembedOut MultiLinear
OProj mlx.Linear `weight:"o_proj"`
}
func (m *Attention) AfterLoad(root *model.Root) ([]*mlx.Array, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts struct {
NumAttentionHeads int `json:"num_attention_heads"`
QKNopeHeadDim int `json:"qk_nope_head_dim"`
KVLoraRank int `json:"kv_lora_rank"`
}
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
w := m.KVBProj.Weight.Reshape(opts.NumAttentionHeads, -1, opts.KVLoraRank)
m.embedQ.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim), mlx.Slice()).Transpose(0, 2, 1))
m.unembedOut.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0), mlx.Slice()))
return []*mlx.Array{
&m.embedQ.Weight,
&m.unembedOut.Weight,
}, nil
}
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
query := m.QAProj.Forward(hiddenStates)
query = m.QALayernorm.Forward(query, opts.RMSNormEps)
query = m.QBProj.Forward(query)
query = query.Reshape(B, L, opts.NumAttentionHeads, -1)
query = query.Transpose(0, 2, 1, 3)
queryNope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim))
queryRope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0))
compressedKV := m.KVAProjWithMQA.Forward(hiddenStates)
keyRope := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(opts.KVLoraRank, 0))
keyRope = keyRope.Reshape(B, L, 1, opts.QKRopeHeadDim)
keyRope = keyRope.Transpose(0, 2, 1, 3)
kvCompressed := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
var offset int
if cache != nil {
offset = cache.Offset()
}
queryRope = opts.RoPE.Forward(queryRope, offset)
keyRope = opts.RoPE.Forward(keyRope, offset)
key := m.KVALayernorm.Forward(kvCompressed, opts.RMSNormEps).
ExpandDims(1).
Concatenate(3, keyRope)
if cache != nil {
key, _ = cache.Update(key, mlx.Zeros(mlx.DTypeBFloat16, B, 1, L, 0))
}
value := key.Clone().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
query = m.embedQ.Forward(queryNope).Concatenate(3, queryRope)
attention := mlx.ScaledDotProductAttention(query, key, value, nil, float32(1.0/math.Sqrt(float64(opts.QKNopeHeadDim+opts.QKRopeHeadDim))))
attention = m.unembedOut.Forward(attention)
attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
return m.OProj.Forward(attention)
}
type MLP interface {
Forward(*mlx.Array, int, int, Options) *mlx.Array
}
type dense struct {
GateProj mlx.Linear `weight:"gate_proj"`
UpProj mlx.Linear `weight:"up_proj"`
DownProj mlx.Linear `weight:"down_proj"`
}
func (m dense) Forward(h *mlx.Array, _, _ int, opts Options) *mlx.Array {
h = mlx.SILU(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h))
return m.DownProj.Forward(h)
}
type Gate struct {
Gate mlx.Linear `weight:"gate"`
CorrectionBias mlx.Array `weight:"gate.e_score_correction_bias"`
}
var expertSelect *mlx.Closure
func ExpertSelect(opts Options) *mlx.Closure {
if expertSelect == nil {
expertSelect = mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
scores, correctionBias := inputs[0], inputs[1]
scores = scores.Sigmoid()
original := scores
scores = scores.Add(correctionBias)
indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))
scores = original.TakeAlongAxis(indices, -1)
if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
}
scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
return []*mlx.Array{indices, scores}
}, false)
}
return expertSelect
}
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
outputs := ExpertSelect(opts).Call([]*mlx.Array{
m.Gate.Forward(h).AsType(mlx.DTypeFloat32),
&m.CorrectionBias,
})
return outputs[0], outputs[1]
}
type sparse struct {
Gate
Experts []dense `weight:"experts"`
fused struct {
GateProj mlx.Linear
UpProj mlx.Linear
DownProj mlx.Linear
}
SharedExperts dense `weight:"shared_experts"`
}
func (m *sparse) AfterLoad(*model.Root) ([]*mlx.Array, error) {
w1 := make([]*mlx.Array, len(m.Experts))
w2 := make([]*mlx.Array, len(m.Experts))
w3 := make([]*mlx.Array, len(m.Experts))
for i := range m.Experts {
w1[i] = &m.Experts[i].GateProj.Weight
w2[i] = &m.Experts[i].UpProj.Weight
w3[i] = &m.Experts[i].DownProj.Weight
}
m.fused.GateProj.Weight.Set(w1[0].StackAxis(0, w1[1:]...))
m.fused.UpProj.Weight.Set(w2[0].StackAxis(0, w2[1:]...))
m.fused.DownProj.Weight.Set(w3[0].StackAxis(0, w3[1:]...))
return []*mlx.Array{
&m.fused.GateProj.Weight,
&m.fused.UpProj.Weight,
&m.fused.DownProj.Weight,
}, nil
}
func (m sparse) Forward(h *mlx.Array, B, L int, opts Options) *mlx.Array {
indices, scores := m.Gate.Forward(h, opts)
scores = scores.ExpandDims(-1)
flat := h.ExpandDims(-2).ExpandDims(-2).Reshape(-1, 1, 1, opts.HiddenSize)
indices = indices.Reshape(-1, opts.NumExpertsPerTok)
sort := B*L >= 64
var inverseOrder *mlx.Array
if sort {
indicesAll := indices.Flatten(0, len(indices.Dims())-1)
order := indicesAll.ArgsortAxis(0)
inverseOrder = order.ArgsortAxis(0)
flat = flat.Squeeze(1).TakeAxis(order.FloorDivide(mlx.FromValue(opts.NumExpertsPerTok)), 0).ExpandDims(1)
indices = indicesAll.TakeAxis(order, 0).Reshape(B*L*opts.NumExpertsPerTok, 1)
}
experts := mlx.SILU(m.fused.GateProj.Gather(flat, nil, indices, sort)).
Multiply(m.fused.UpProj.Gather(flat, nil, indices, sort))
experts = m.fused.DownProj.Gather(experts, nil, indices, sort)
if sort {
experts = experts.Squeeze(2).Squeeze(1).TakeAxis(inverseOrder, 0)
experts = experts.Reshape(-1, opts.NumExpertsPerTok, opts.HiddenSize)
} else {
experts = experts.Squeeze(2)
}
experts = experts.Reshape(B, L, opts.NumExpertsPerTok, opts.HiddenSize)
experts = experts.Multiply(scores).SumAxis(-2, false).AsType(experts.DType())
experts = experts.Add(m.SharedExperts.Forward(h, B, L, opts))
return experts.Reshape(B, L, -1)
}
func init() {
base.Register("Glm4MoeLiteForCausalLM", func(root *model.Root) (base.Model, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts Options
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
opts.RoPE = mlx.RoPE{
Dims: opts.QKRopeHeadDim,
Traditional: true,
Base: opts.RopeTheta,
Scale: 1,
}
layers := make([]Layer, opts.NumHiddenLayers)
for i := range layers {
if i < opts.FirstKDenseReplace {
layers[i].MLP = &dense{}
} else {
layers[i].MLP = &sparse{Experts: make([]dense, opts.NumRoutedExperts)}
}
}
return &Model{
Layers: layers,
Options: opts,
}, nil
})
}

View File

@@ -1,130 +0,0 @@
package llama
import (
"encoding/json"
"log/slog"
"math"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
type Options struct {
HiddenAct string `json:"hidden_act"`
HiddenSize int `json:"hidden_size"`
IntermediateSize int `json:"intermediate_size"`
NumAttentionHeads int `json:"num_attention_heads"`
NumHiddenLayers int `json:"num_hidden_layers"`
NumKeyValueHeads int `json:"num_key_value_heads"`
RMSNormEps float32 `json:"rms_norm_eps"`
mlx.RoPE
}
type Model struct {
EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
Layers []Layer `weight:"model.layers"`
Norm mlx.RMSNorm `weight:"model.norm"`
Output mlx.Linear `weight:"lm_head"`
Options
}
func (m Model) NumLayers() int {
return len(m.Layers)
}
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
B, L := inputs.Dim(0), inputs.Dim(1)
hiddenStates := m.EmbedTokens.Forward(inputs)
for i, layer := range m.Layers {
hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options)
}
hiddenStates = m.Norm.Forward(hiddenStates, m.RMSNormEps)
hiddenStates = m.Output.Forward(hiddenStates)
slog.Debug("Model.forward", "output shape", hiddenStates.Dims(), "m.Output", m.Output.Weight.Dims())
return hiddenStates
}
type Layer struct {
AttentionNorm mlx.RMSNorm `weight:"input_layernorm"`
Attention Attention `weight:"self_attn"`
MLPNorm mlx.RMSNorm `weight:"post_attention_layernorm"`
MLP MLP `weight:"mlp"`
}
func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
hiddenStates = hiddenStates.Add(residual)
residual = hiddenStates
hiddenStates = m.MLPNorm.Forward(hiddenStates, opts.RMSNormEps)
hiddenStates = m.MLP.Forward(hiddenStates)
hiddenStates = hiddenStates.Add(residual)
return hiddenStates
}
type Attention struct {
QueryProj mlx.Linear `weight:"q_proj"`
KeyProj mlx.Linear `weight:"k_proj"`
ValueProj mlx.Linear `weight:"v_proj"`
OutputProj mlx.Linear `weight:"o_proj"`
}
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
query := m.QueryProj.Forward(hiddenStates)
query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)
key := m.KeyProj.Forward(hiddenStates)
key = key.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)
value := m.ValueProj.Forward(hiddenStates)
value = value.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)
query = opts.RoPE.Forward(query, cache.Offset())
key = opts.RoPE.Forward(key, cache.Offset())
key, value = cache.Update(key, value)
attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(key.Dim(-1)))))
attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
return m.OutputProj.Forward(attention)
}
type MLP struct {
Gate mlx.Linear `weight:"gate_proj"`
Up mlx.Linear `weight:"up_proj"`
Down mlx.Linear `weight:"down_proj"`
}
func (m MLP) Forward(h *mlx.Array) *mlx.Array {
return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
}
func init() {
base.Register("MistralForCausalLM", func(root *model.Root) (base.Model, error) {
bts, err := root.ReadFile("config.json")
if err != nil {
return nil, err
}
var opts Options
// TODO: implement json.Unmarshal for Options
if err := json.Unmarshal(bts, &opts); err != nil {
return nil, err
}
if err := json.Unmarshal(bts, &opts.RoPE); err != nil {
return nil, err
}
return &Model{
Layers: make([]Layer, opts.NumHiddenLayers),
Options: opts,
}, nil
})
}

View File

@@ -1,7 +0,0 @@
package model
import (
_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
_ "github.com/ollama/ollama/x/mlxrunner/model/glm/4/moe/lite"
_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
)

View File

@@ -1,138 +0,0 @@
package mlxrunner
import (
"bytes"
"errors"
"log/slog"
"time"
"unicode/utf8"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
func (r *Runner) TextGenerationPipeline(request Request) error {
model, ok := r.Model.(base.TextGeneration)
if !ok {
return errors.New("model does not support causal language modeling")
}
inputs, err := r.Tokenizer.Encode(request.Prompt, true)
if err != nil {
return err
}
caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 {
if cacher, ok := model.(base.Cacher); ok {
caches = cacher.Cache()
} else {
caches = make([]cache.Cache, model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
}
}
total, processed := len(tokens), 0
slog.Info("Prompt processing progress", "processed", processed, "total", total)
for total-processed > 1 {
n := min(2<<10, total-processed-1)
temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
defer mlx.Free(temp)
mlx.Eval(func() []*mlx.Array {
s := make([]*mlx.Array, 2*len(caches))
for i, c := range caches {
s[2*i], s[2*i+1] = c.State()
}
return s
}()...)
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
// TODO: additional logit processing (logit bias, repetition penalty, etc.)
logprobs := logits.Subtract(logits.Logsumexp(true))
return request.Sample(logprobs), logprobs
}
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
mlx.AsyncEval(sample, logprobs)
// buffer partial, multibyte unicode
var b bytes.Buffer
now := time.Now()
final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
outputs := make([]int32, 0, request.Options.MaxTokens)
for i := range request.Options.MaxTokens {
nextSample, nextLogprobs := step(sample)
mlx.AsyncEval(nextSample, nextLogprobs)
if i == 0 {
slog.Info("Prompt processing progress", "processed", total, "total", total)
mlx.Eval(sample)
final.PromptTokensDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
outputs = append(outputs, output)
if r.Tokenizer.Is(output, tokenizer.SpecialEOS) {
final.Token = int(output)
final.DoneReason = 0
final.CompletionTokens = i
break
}
request.Responses <- Response{
Text: r.Decode(output, &b),
Token: int(output),
}
mlx.Free(sample, logprobs)
if i%256 == 0 {
mlx.ClearCache()
}
sample, logprobs = nextSample, nextLogprobs
}
mlx.Free(sample, logprobs)
final.CompletionTokensDuration = time.Since(now)
request.Responses <- final
r.InsertCache(append(inputs, outputs...), caches)
return nil
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token, err := r.Tokenizer.Decode([]int32{sample})
if err != nil {
slog.Error("Failed to decode tokens", "error", err)
return ""
}
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
}
if text := b.String(); utf8.ValidString(text) {
b.Reset()
return text
} else if b.Len() >= utf8.UTFMax {
b.Reset()
return text
}
return ""
}

View File

@@ -1,110 +0,0 @@
package mlxrunner
import (
"context"
"log/slog"
"net"
"net/http"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/tokenizer"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
_ "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
type Request struct {
TextCompletionsRequest
Responses chan Response
Pipeline func(Request) error
sample.Sampler
caches []cache.Cache
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Response struct {
Text string `json:"content,omitempty"`
Token int `json:"token,omitempty"`
Logprobs []float32 `json:"logprobs,omitempty"`
Done bool `json:"done,omitempty"`
DoneReason int `json:"done_reason,omitempty"`
PromptTokens int `json:"prompt_eval_count,omitempty"`
PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
CompletionTokens int `json:"eval_count,omitempty"`
CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type Runner struct {
Model base.Model
Tokenizer tokenizer.Tokenizer
Requests chan Request
CacheEntries map[int32]*CacheEntry
}
func (r *Runner) Load(name model.Name) (err error) {
root, err := model.Open(name)
if err != nil {
return err
}
defer root.Close()
r.Model, err = base.New(root)
if err != nil {
return err
}
r.Tokenizer, err = tokenizer.New(root)
if err != nil {
return err
}
weights, quantizations, afterLoadFuncs := base.Walk(r.Model)
return mlx.LoadAll(root, weights, quantizations, afterLoadFuncs)
}
func (r *Runner) Run(host, port string, mux http.Handler) error {
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
break
}
close(request.Responses)
}
}
})
g.Go(func() error {
slog.Info("Starting HTTP server", "host", host, "port", port)
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
})
return g.Wait()
}

View File

@@ -1,75 +0,0 @@
package sample
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Sampler interface {
Sample(*mlx.Array) *mlx.Array
}
func New(temp, top_p, min_p float32, top_k int) Sampler {
if temp == 0 {
return greedy{}
}
var samplers []Sampler
if top_p > 0 && top_p < 1 {
samplers = append(samplers, TopP(top_p))
}
if min_p != 0 {
samplers = append(samplers, MinP(min_p))
}
if top_k > 0 {
samplers = append(samplers, TopK(top_k))
}
samplers = append(samplers, Temperature(temp))
return chain(samplers)
}
type greedy struct{}
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
type chain []Sampler
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
for _, sampler := range c {
logits = sampler.Sample(logits)
}
return logits
}
type Temperature float32
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}
type TopP float32
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type MinP float32
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
// TODO: implement
return logprobs
}
type TopK int
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}

View File

@@ -1,180 +0,0 @@
package mlxrunner
import (
"bytes"
"cmp"
"encoding/json"
"flag"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
var (
name model.Name
port int
)
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
flagSet.Var(&name, "model", "Model name")
flagSet.IntVar(&port, "port", 0, "Port to listen on")
_ = flagSet.Bool("verbose", false, "Enable debug logging")
flagSet.Parse(args)
runner := Runner{
Requests: make(chan Request),
CacheEntries: make(map[int32]*CacheEntry),
}
if err := runner.Load(name); err != nil {
return err
}
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(map[string]any{
"status": 0,
"progress": 100,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "POST":
fallthrough
case "GET":
if err := json.NewEncoder(w).Encode(map[string]any{
"Success": true,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
case "DELETE":
// TODO: cleanup model and cache
}
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan Response)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
if request.Options.MaxTokens < 1 {
request.Options.MaxTokens = 16 << 10
}
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
)
runner.Requests <- request
w.Header().Set("Content-Type", "application/jsonl")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
for response := range request.Responses {
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
})
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
var b bytes.Buffer
if _, err := io.Copy(&b, r.Body); err != nil {
slog.Error("Failed to read request body", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
tokens, err := runner.Tokenizer.Encode(b.String(), true)
if err != nil {
slog.Error("Failed to tokenize text", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(tokens); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
for source, target := range map[string]string{
"GET /health": "/v1/status",
"POST /load": "/v1/models",
"POST /completion": "/v1/completions",
} {
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
}
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
t := time.Now()
mux.ServeHTTP(recorder, r)
var level slog.Level
switch {
case recorder.code >= 500:
level = slog.LevelError
case recorder.code >= 400:
level = slog.LevelWarn
case recorder.code >= 300:
return
}
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
}))
}
type statusRecorder struct {
http.ResponseWriter
code int
}
func (w *statusRecorder) WriteHeader(code int) {
w.code = code
w.ResponseWriter.WriteHeader(code)
}
func (w *statusRecorder) Status() string {
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}
func (w *statusRecorder) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}