mirror of https://github.com/ollama/ollama.git
synced 2026-02-06 13:43:39 -05:00

Compare commits (3): brucemacd/... → mxyng/mlx
| Author | SHA1 | Date |
|---|---|---|
| | b367596c58 | |
| | 7c027625ef | |
| | 092ffbb2f6 | |
go.mod (5 lines changed)

@@ -13,7 +13,7 @@ require (
	github.com/mattn/go-sqlite3 v1.14.24
	github.com/olekukonko/tablewriter v0.0.5
	github.com/spf13/cobra v1.7.0
-	github.com/stretchr/testify v1.9.0
+	github.com/stretchr/testify v1.10.0
	github.com/x448/float16 v0.8.4
	golang.org/x/sync v0.17.0
	golang.org/x/sys v0.37.0
@@ -29,6 +29,8 @@ require (
	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
	github.com/tkrajina/typescriptify-golang-structs v0.2.0
	github.com/tree-sitter/go-tree-sitter v0.25.0
	github.com/tree-sitter/tree-sitter-cpp v0.23.4
	github.com/wk8/go-ordered-map/v2 v2.1.8
	golang.org/x/image v0.22.0
	golang.org/x/mod v0.30.0
@@ -50,6 +52,7 @@ require (
	github.com/google/flatbuffers v24.3.25+incompatible // indirect
	github.com/kr/text v0.2.0 // indirect
	github.com/mailru/easyjson v0.7.7 // indirect
	github.com/mattn/go-pointer v0.0.1 // indirect
	github.com/pkg/errors v0.9.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/rivo/uniseg v0.2.0 // indirect
go.sum (31 lines changed)

@@ -152,6 +152,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -206,12 +208,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -4,6 +4,7 @@ import (
	"github.com/ollama/ollama/runner/llamarunner"
	"github.com/ollama/ollama/runner/ollamarunner"
	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/mlxrunner"
)

func Execute(args []string) error {
@@ -17,6 +18,8 @@ func Execute(args []string) error {
			return ollamarunner.Execute(args[1:])
		case "--imagegen-engine":
			return imagegen.Execute(args[1:])
+		case "--mlx-engine":
+			return mlxrunner.Execute(args[1:])
		}
	}
	return llamarunner.Execute(args)
server/sched.go (107 lines changed)

@@ -5,9 +5,13 @@ import (
	"errors"
	"fmt"
	"log/slog"
+	"math/rand"
+	"os"
+	"os/exec"
	"reflect"
	"slices"
	"sort"
+	"strconv"
	"strings"
	"sync"
	"time"
@@ -22,6 +26,7 @@ import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/imagegen"
+	"github.com/ollama/ollama/x/mlxrunner"
)

type LlmRequest struct {
@@ -195,25 +200,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
			slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
		}

-		// Check for image generation models - all use MLX runner
-		if slices.Contains(pending.model.Config.Capabilities, "image") {
-			if s.loadMLX(pending) {
+		// Check for experimental safetensors LLM models
+		if pending.model.Config.ModelFormat == "safetensors" {
+			if s.loadSafetensors(pending) {
				break
			}
			continue
		}

-		// Check for experimental safetensors LLM models
-		if pending.model.Config.ModelFormat == "safetensors" {
-			if slices.Contains(pending.model.Config.Capabilities, "completion") {
-				// LLM model with safetensors format - use MLX runner
-				if s.loadMLX(pending) {
-					break
-				}
-				continue
-			}
-		}

		// Load model for fitting
		logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
		ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
@@ -563,9 +557,90 @@ iGPUScan:
	return false
}

-// loadMLX loads an experimental safetensors model using the unified MLX runner.
+func subproc(args, environ []string) (*exec.Cmd, int, error) {
+	exe, err := os.Executable()
+	if err != nil {
+		return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
+	}
+
+	for range 3 {
+		// get a random port in the ephemeral range
+		port := rand.Intn(65535-49152) + 49152
+		cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
+		cmd.Env = slices.Concat(os.Environ(), environ)
+		cmd.Stdout = os.Stderr
+		cmd.Stderr = os.Stderr
+		if err := cmd.Start(); err != nil {
+			continue
+		}
+
+		return cmd, port, nil
+	}
+
+	return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
+}
+
+func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
+	if slices.Contains(req.model.Config.Capabilities, "image") {
+		return s.loadImageGen(req)
+	}
+
+	args := []string{"--mlx-engine", "--model", req.model.ShortName}
+	environ := []string{}
+	cmd, port, err := subproc(args, environ)
+	if err != nil {
+		req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
+		return true
+	}
+
+	sessionDuration := envconfig.KeepAlive()
+	if req.sessionDuration != nil {
+		sessionDuration = req.sessionDuration.Duration
+	}
+
+	runner := &runnerRef{
+		model:           req.model,
+		modelPath:       req.model.ModelPath,
+		Options:         &req.opts,
+		loading:         false,
+		sessionDuration: sessionDuration,
+		llama: &mlxrunner.Client{
+			Cmd:  cmd,
+			Port: port,
+		},
+	}
+
+	s.loadedMu.Lock()
+	s.loaded[req.model.ModelPath] = runner
+	s.loadedMu.Unlock()
+
+	runner.refMu.Lock()
+	if sessionDuration > 0 {
+		runner.expireTimer = time.AfterFunc(sessionDuration, func() {
+			s.expiredCh <- runner
+		})
+	}
+	runner.refMu.Unlock()
+	req.useLoadedRunner(runner, s.finishedReqCh)
+
+	for range time.Tick(20 * time.Millisecond) {
+		if err := func() error {
+			ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+			defer cancel()
+			return runner.llama.Ping(ctx)
+		}(); err != nil {
+			continue
+		}
+
+		break
+	}
+
+	return true
+}
+
+// loadImageGen loads an experimental safetensors model using the unified MLX runner.
-// This supports both LLM (completion) and image generation models.
-func (s *Scheduler) loadMLX(req *LlmRequest) bool {
+func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
	// Determine mode based on capabilities
	var mode imagegen.ModelMode
	if slices.Contains(req.model.Config.Capabilities, "image") {
@@ -1,5 +1,14 @@
package tokenizer

import (
	"encoding/json"
	"errors"
	"io"
	"os"

	"github.com/ollama/ollama/types/model"
)

const (
	TOKEN_TYPE_NORMAL = iota + 1
	TOKEN_TYPE_UNKNOWN
@@ -15,3 +24,287 @@ type Tokenizer interface {
	Is(int32, Special) bool
	Vocabulary() *Vocabulary
}

func New(root *model.Root) (Tokenizer, error) {
	f, err := root.Open("tokenizer.json")
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var tokenizer struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges json.RawMessage  `json:"merges"`
		} `json:"model"`

		PreTokenizer json.RawMessage `json:"pre_tokenizer"`
		Decoder      json.RawMessage `json:"decoder"`

		AddedTokens []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
	}

	if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
		return nil, err
	}

	special := make(map[int32]struct{})
	for _, token := range tokenizer.AddedTokens {
		tokenizer.Model.Vocab[token.Content] = token.ID
		special[token.ID] = struct{}{}
	}

	vocab, err := specialTokens(root, tokenizer.Model.Vocab)
	if err != nil {
		return nil, err
	}

	vocab.Values = make([]string, len(tokenizer.Model.Vocab))
	vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
	vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
	for content, id := range tokenizer.Model.Vocab {
		vocab.Values[id] = content
		vocab.Scores[id] = float32(id)
		vocab.Types[id] = TOKEN_TYPE_NORMAL
		if _, ok := special[id]; ok {
			vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
		}
	}

	if tokenizer.Model.Merges != nil {
		var pairs [][]string
		if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
			vocab.Merges = make([]string, len(pairs))
			for i, pair := range pairs {
				vocab.Merges[i] = pair[0] + " " + pair[1]
			}
		} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
			return nil, err
		}
	}

	vocab.valuesOnce.Do(func() {})
	vocab.values = tokenizer.Model.Vocab

	if tokenizer.Model.Type == "WordPiece" {
		return NewWordPiece(vocab, true), nil
	}

	if tokenizer.Decoder != nil {
		var decoder struct {
			Type     string `json:"type"`
			Decoders []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"string"`
				} `json:"pattern"`
			} `json:"decoders"`
		}

		if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
			return nil, err
		}

		if decoder.Type == "Sequence" {
			for _, d := range decoder.Decoders {
				if d.Type == "Replace" && d.Pattern.String == "▁" {
					return NewSentencePiece(vocab), nil
				}
			}
		}
	}

	var pretokenizers []string
	if tokenizer.PreTokenizer != nil {
		var pretokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					Regex string
				} `json:"pattern"`
				IndividualDigits bool `json:"individual_digits"`
			}
		}

		if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
			return nil, err
		}

		if pretokenizer.Type == "Sequence" {
			for _, pretokenizer := range pretokenizer.Pretokenizers {
				switch pretokenizer.Type {
				case "Digits":
					if pretokenizer.IndividualDigits {
						pretokenizers = append(pretokenizers, `\d`)
					} else {
						pretokenizers = append(pretokenizers, `\d+`)
					}
				case "Punctuation":
					pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
				case "Split":
					pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
				case "WhitespaceSplit":
					pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
				}
			}
		}
	}

	return NewBytePairEncoding(vocab, pretokenizers...), nil
}

// valueOrValues is a type that can unmarshal from either a single value or an array of values.
type valueOrValues[E any] []E

func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
	var s []E
	if err := json.Unmarshal(data, &s); err != nil {
		var e E
		if err := json.Unmarshal(data, &e); err != nil {
			return err
		}
		s = []E{e}
	}
	*m = valueOrValues[E](s)
	return nil
}
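valueOrValues exists because Hugging Face configs encode fields like eos_token_id either as a single ID or as a list of IDs. A minimal standalone sketch of the behavior (the demo values are made up):

package main

import (
	"encoding/json"
	"fmt"
)

// valueOrValues mirrors the helper above: try the array form first,
// then fall back to wrapping a single value in a one-element slice.
type valueOrValues[E any] []E

func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
	var s []E
	if err := json.Unmarshal(data, &s); err != nil {
		var e E
		if err := json.Unmarshal(data, &e); err != nil {
			return err
		}
		s = []E{e}
	}
	*m = valueOrValues[E](s)
	return nil
}

func main() {
	var scalar, list valueOrValues[int32]
	json.Unmarshal([]byte(`2`), &scalar)   // scalar form
	json.Unmarshal([]byte(`[2,3]`), &list) // array form
	fmt.Println(scalar, list)              // [2] [2 3]
}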

type specialTokenIDs struct {
	BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
	EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
}

// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
type stringOrContent string

func (t *stringOrContent) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		var m map[string]any
		if err := json.Unmarshal(data, &m); err != nil {
			return err
		}
		if content, ok := m["content"].(string); ok {
			s = content
		}
	}
	*t = stringOrContent(s)
	return nil
}

func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
	var vocab Vocabulary
	for _, c := range []struct {
		name string
		fn   func(io.Reader) error
	}{
		{
			name: "generation_config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				vocab.BOS = c.BOSTokenID
				vocab.EOS = c.EOSTokenID
				return nil
			},
		},
		{
			name: "config.json",
			fn: func(r io.Reader) error {
				var c specialTokenIDs
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if len(vocab.BOS) == 0 {
					vocab.BOS = c.BOSTokenID
				}

				if len(vocab.EOS) == 0 {
					vocab.EOS = c.EOSTokenID
				}

				return nil
			},
		},
		{
			name: "tokenizer_config.json",
			fn: func(r io.Reader) error {
				var c struct {
					BOSToken    stringOrContent `json:"bos_token"`
					EOSToken    stringOrContent `json:"eos_token"`
					PADToken    stringOrContent `json:"pad_token"`
					AddBOSToken bool            `json:"add_bos_token"`
					AddEOSToken bool            `json:"add_eos_token"`
				}
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if len(vocab.BOS) == 0 && c.BOSToken != "" {
					if id, ok := values[string(c.BOSToken)]; ok {
						vocab.BOS = []int32{id}
					}
				}

				if len(vocab.EOS) == 0 && c.EOSToken != "" {
					if id, ok := values[string(c.EOSToken)]; ok {
						vocab.EOS = []int32{id}
					}
				}

				vocab.AddBOS = c.AddBOSToken
				vocab.AddEOS = c.AddEOSToken
				return nil
			},
		},
		{
			name: "special_tokens_map.json",
			fn: func(r io.Reader) error {
				var c map[string]stringOrContent
				if err := json.NewDecoder(r).Decode(&c); err != nil {
					return err
				}

				if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
					if id, ok := values[string(bos)]; ok {
						vocab.BOS = []int32{id}
					}
				}

				if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
					if id, ok := values[string(eos)]; ok {
						vocab.EOS = []int32{id}
					}
				}

				return nil
			},
		},
	} {
		if err := func() error {
			f, err := root.Open(c.name)
			if errors.Is(err, os.ErrNotExist) {
				return nil
			} else if err != nil {
				return err
			}
			defer f.Close()

			return c.fn(f)
		}(); err != nil {
			return nil, err
		}
	}

	return &vocab, nil
}
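Note the precedence: generation_config.json wins, then config.json, then tokenizer_config.json, then special_tokens_map.json, each filling BOS/EOS only if still unset. A hedged usage sketch tying New and the model store together — the import path for this tokenizer package and the model name are assumptions, not taken from the diff:

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/tokenizer" // assumed import path
)

func main() {
	// Hypothetical model name; Open reads the manifest under OLLAMA_MODELS.
	root, err := model.Open(model.ParseName("namespace/model"))
	if err != nil {
		log.Fatal(err)
	}
	defer root.Close()

	tok, err := tokenizer.New(root) // picks WordPiece, SentencePiece, or BPE
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(tok.Vocabulary().Values), "vocabulary entries")
}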
types/model/file.go (new file, 309 lines)

@@ -0,0 +1,309 @@
package model

import (
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"hash"
	"io"
	"io/fs"
	"iter"
	"maps"
	"mime"
	"os"
	"path/filepath"
	"strings"

	"github.com/ollama/ollama/envconfig"
)

func root() (*os.Root, error) {
	root, err := os.OpenRoot(envconfig.Models())
	if err != nil {
		return nil, err
	}

	for _, sub := range []string{"manifests", "blobs"} {
		if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
			if err := root.MkdirAll(sub, 0o750); err != nil {
				return nil, err
			}
		} else if err != nil {
			return nil, err
		}
	}

	return root, nil
}

// Open opens an existing file for reading. It will return [fs.ErrNotExist]
// if the file does not exist. The returned [*Root] can only be used for reading.
// It is the caller's responsibility to close the file when done.
func Open(n Name) (*Root, error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	f, err := r.Open(filepath.Join("manifests", n.Filepath()))
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var m manifest
	if err := json.NewDecoder(f).Decode(&m); err != nil {
		return nil, err
	}

	blobs := make(map[string]*blob, len(m.Layers)+1)
	blobs[NamePrefix] = m.Config
	for _, layer := range m.Layers {
		if layer.Name == "" && layer.MediaType != "" {
			mediatype, _, err := mime.ParseMediaType(layer.MediaType)
			if err != nil {
				return nil, err
			}

			if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
				layer.Name = NamePrefix + suffix
			}
		}

		blobs[layer.Name] = layer
	}

	return &Root{
		root:  r,
		name:  n,
		blobs: blobs,
		flags: os.O_RDONLY,
	}, nil
}

// Create creates a new file. The returned [Root] can be used for both reading
// and writing. It is the caller's responsibility to close the file when done
// in order to finalize any new blobs and write the manifest.
func Create(n Name) (*Root, error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	return &Root{
		root:  r,
		name:  n,
		blobs: make(map[string]*blob),
		flags: os.O_RDWR,
	}, nil
}

type blob struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Name      string `json:"name,omitempty"`
	Size      int64  `json:"size"`

	// tempfile is the temporary file where the blob data is written.
	tempfile *os.File

	// hash is the hash.Hash used to compute the blob digest.
	hash hash.Hash
}

func (b *blob) Write(p []byte) (int, error) {
	return io.MultiWriter(b.tempfile, b.hash).Write(p)
}

func (b *blob) Filepath() string {
	return strings.ReplaceAll(b.Digest, ":", "-")
}

type manifest struct {
	SchemaVersion int    `json:"schemaVersion"`
	MediaType     string `json:"mediaType"`
	Config        *blob  `json:"config"`
	Layers        []*blob `json:"layers"`
}

// Root represents a model file. It can be used to read and write blobs
// associated with the model.
//
// Blobs are identified by name. Certain names are special and reserved;
// see [NamePrefix] for details.
type Root struct {
	root  *os.Root
	name  Name
	blobs map[string]*blob
	flags int
}

const MediaTypePrefix = "application/vnd.ollama"

// NamePrefix is the prefix used for identifying special names. Names
// with this prefix are identified by their media types:
//
//   - name: NamePrefix + suffix
//   - mediaType: [MediaTypePrefix] + suffix
//
// For example:
//
//   - name: "./..image.model"
//   - mediaType: "application/vnd.ollama.image.model"
//
// NamePrefix by itself identifies the manifest config.
const NamePrefix = "./."

// Open opens the named blob for reading. It is the caller's responsibility
// to close the returned [io.ReadCloser] when done. It will return
// [fs.ErrNotExist] if the blob does not exist.
func (r Root) Open(name string) (io.ReadCloser, error) {
	if b, ok := r.blobs[name]; ok {
		r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
		if err != nil {
			return nil, err
		}
		return r, nil
	}

	return nil, fs.ErrNotExist
}

func (r Root) ReadFile(name string) ([]byte, error) {
	f, err := r.Open(name)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	return io.ReadAll(f)
}

// Create creates or replaces a named blob in the file. If the blob already
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
// was opened in read-only mode. The returned [io.Writer] can be used to write
// to the blob and does not need to be closed, but the file must be closed to
// finalize the blob.
func (r *Root) Create(name string) (io.Writer, error) {
	if r.flags&os.O_RDWR != 0 {
		w, err := os.CreateTemp(r.root.Name(), "")
		if err != nil {
			return nil, err
		}

		r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
		return r.blobs[name], nil
	}

	return nil, fs.ErrInvalid
}

// Close closes the file. If the file was opened in read-write mode, it
// will finalize any writeable blobs and write the manifest.
func (r *Root) Close() error {
	if r.flags&os.O_RDWR != 0 {
		for _, b := range r.blobs {
			if b.tempfile != nil {
				fi, err := b.tempfile.Stat()
				if err != nil {
					return err
				}

				if err := b.tempfile.Close(); err != nil {
					return err
				}

				b.Size = fi.Size()
				b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))

				if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
					if b.Name == NamePrefix {
						b.MediaType = "application/vnd.docker.container.image.v1+json"
					} else {
						b.MediaType = MediaTypePrefix + suffix
					}
					b.Name = ""
				}

				rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
				if err != nil {
					return err
				}

				if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
					return err
				}
			}
		}

		p := filepath.Join("manifests", r.name.Filepath())
		if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
			if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
				return err
			}
		} else if err != nil {
			return err
		}

		f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
		if err != nil {
			return err
		}
		defer f.Close()

		if err := json.NewEncoder(f).Encode(manifest{
			SchemaVersion: 2,
			MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
			Config:        r.blobs[NamePrefix],
			Layers: func() []*blob {
				blobs := make([]*blob, 0, len(r.blobs))
				for name, b := range r.blobs {
					if name != NamePrefix {
						blobs = append(blobs, b)
					}
				}
				return blobs
			}(),
		}); err != nil {
			return err
		}
	}

	return r.root.Close()
}

// Name returns the name of the file.
func (r Root) Name() Name {
	return r.name
}

// Names returns an iterator over the names in the file.
func (r Root) Names() iter.Seq[string] {
	return maps.Keys(r.blobs)
}

// Glob returns an iterator over the names in the file that match the given
// pattern.
//
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
// the only possible returned error is ErrBadPattern, when pattern is malformed.
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
	if _, err := filepath.Match(pattern, ""); err != nil {
		return nil, err
	}

	return func(yield func(string) bool) {
		for name, blob := range r.blobs {
			if matched, _ := filepath.Match(pattern, name); matched {
				if !yield(blob.Filepath()) {
					return
				}
			}
		}
	}, nil
}

func (r Root) JoinPath(parts ...string) string {
	return filepath.Join(append([]string{r.root.Name()}, parts...)...)
}
types/model/file_test.go (new file, 90 lines)

@@ -0,0 +1,90 @@
package model

import (
	"io"
	"strings"
	"testing"
)

// setup is a helper function to set up the test environment.
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	for m, s := range models {
		f, err := Create(m)
		if err != nil {
			t.Fatal(err)
		}

		for n, r := range s {
			w, err := f.Create(n)
			if err != nil {
				t.Fatal(err)
			}

			if _, err := io.Copy(w, r); err != nil {
				t.Fatal(err)
			}
		}

		if err := f.Close(); err != nil {
			t.Fatal(err)
		}
	}
}

func TestOpen(t *testing.T) {
	setup(t, map[Name]map[string]io.Reader{
		ParseName("namespace/model"): {
			"./.": strings.NewReader(`{"key":"value"}`),
		},
		ParseName("namespace/model:8b"): {
			"./.": strings.NewReader(`{"foo":"bar"}`),
		},
		ParseName("another/model"): {
			"./.": strings.NewReader(`{"another":"config"}`),
		},
	})

	f, err := Open(ParseName("namespace/model"))
	if err != nil {
		t.Fatal(err)
	}

	for _, name := range []string{"./."} {
		r, err := f.Open(name)
		if err != nil {
			t.Fatal(err)
		}

		if _, err := io.ReadAll(r); err != nil {
			t.Fatal(err)
		}

		if err := r.Close(); err != nil {
			t.Fatal(err)
		}
	}

	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	t.Run("does not exist", func(t *testing.T) {
		if _, err := Open(ParseName("namespace/unknown")); err == nil {
			t.Error("expected error for unknown model")
		}
	})

	t.Run("write", func(t *testing.T) {
		f, err := Open(ParseName("namespace/model"))
		if err != nil {
			t.Fatal(err)
		}
		defer f.Close()

		if _, err := f.Create("new-blob"); err == nil {
			t.Error("expected error creating blob in read-only mode")
		}
	})
}
types/model/files.go (new file, 33 lines)

@@ -0,0 +1,33 @@
package model

import (
	"io/fs"
	"iter"
	"path/filepath"
)

func All() (iter.Seq[Name], error) {
	r, err := root()
	if err != nil {
		return nil, err
	}

	manifests, err := r.OpenRoot("manifests")
	if err != nil {
		return nil, err
	}

	matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
	if err != nil {
		return nil, err
	}

	return func(yield func(Name) bool) {
		for _, match := range matches {
			name := ParseNameFromFilepath(filepath.ToSlash(match))
			if !yield(name) {
				return
			}
		}
	}, nil
}
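A short sketch of iterating the manifest tree with All; names come back already parsed:

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/types/model"
)

func main() {
	names, err := model.All()
	if err != nil {
		log.Fatal(err)
	}
	for name := range names {
		fmt.Println(name.String())
	}
}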

@@ -227,6 +227,17 @@ func (n Name) String() string {
	return b.String()
}

// Set implements [flag.Value]. It parses the provided input as a name string
// and sets the receiver to the parsed value. If the parsed name is not valid,
// ErrUnqualifiedName is returned.
func (n *Name) Set(s string) error {
	*n = ParseName(s)
	if !n.IsValid() {
		return ErrUnqualifiedName
	}
	return nil
}
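Because Set pairs with the existing String method, *Name satisfies flag.Value and can be bound to a command-line flag directly; a minimal sketch (the flag name is an assumption):

package main

import (
	"flag"
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	var name model.Name
	flag.Var(&name, "model", "fully qualified model name")
	flag.Parse() // e.g. -model registry.ollama.ai/library/llama3:latest
	fmt.Println(name.DisplayShortest())
}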

// DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string {
	var sb strings.Builder
x/mlxrunner/cache.go (new file, 94 lines)

@@ -0,0 +1,94 @@
package mlxrunner

import (
	"log/slog"

	"github.com/ollama/ollama/x/mlxrunner/cache"
)

type CacheEntry struct {
	Caches  []cache.Cache
	Count   int
	Entries map[int32]*CacheEntry
}

func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
	current := &CacheEntry{Entries: s.CacheEntries}
	index, cacheIndex := 0, -1
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			break
		}

		current = current.Entries[token]
		if len(current.Caches) > 0 {
			cacheIndex = index
		}

		index += 1
	}

	if cacheIndex == len(tokens)-1 {
		slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
		return current.Caches, []int32{}
	} else if cacheIndex > 1 {
		slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
		return current.Caches, tokens[cacheIndex+1:]
	} else if index > 0 && cacheIndex < 0 {
		type stackItem struct {
			entry  *CacheEntry
			tokens []int32
		}

		var best, item stackItem
		stack := []stackItem{{entry: current, tokens: []int32{}}}
		for len(stack) > 0 {
			item, stack = stack[len(stack)-1], stack[:len(stack)-1]
			if len(item.entry.Caches) > 0 {
				if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
					best = item
				}
			} else {
				for token, entry := range item.entry.Entries {
					stack = append(stack, stackItem{
						entry:  entry,
						tokens: append(item.tokens, token),
					})
				}
			}
		}

		prefix := min(len(tokens)-1, index)
		caches := make([]cache.Cache, len(best.entry.Caches))
		trim := len(best.tokens) + 1
		for i := range caches {
			caches[i] = best.entry.Caches[i].Clone()
			caches[i].Trim(trim)
		}

		slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
		return caches, tokens[prefix:]
	}

	slog.Info("Cache miss", "left", len(tokens))
	return nil, tokens
}

func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
	current := &CacheEntry{Entries: s.CacheEntries}
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			current.Entries[token] = &CacheEntry{
				Entries: make(map[int32]*CacheEntry),
			}
		}

		current = current.Entries[token]
	}

	if len(current.Caches) > 0 {
		current.Count += 1
	} else {
		current.Caches = caches
	}
}
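FindNearestCache walks a trie keyed by token ID and prefers the deepest node holding a KV snapshot. Stripped of the MLX arrays, the lookup reduces to this simplified, self-contained sketch:

package main

import "fmt"

// node is a reduced CacheEntry: a trie over token IDs where some nodes
// carry a cached snapshot.
type node struct {
	hasSnapshot bool
	children    map[int32]*node
}

// longestCachedPrefix walks as far as the prompt matches and returns how
// many tokens are covered plus the suffix still to be prefilled.
func longestCachedPrefix(root *node, tokens []int32) (int, []int32) {
	cur, last := root, -1
	for i, tok := range tokens {
		next, ok := cur.children[tok]
		if !ok {
			break
		}
		cur = next
		if cur.hasSnapshot {
			last = i
		}
	}
	return last + 1, tokens[last+1:]
}

func main() {
	root := &node{children: map[int32]*node{
		1: {children: map[int32]*node{
			2: {hasSnapshot: true, children: map[int32]*node{}},
		}},
	}}
	n, rest := longestCachedPrefix(root, []int32{1, 2, 3})
	fmt.Println(n, rest) // 2 [3]
}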
x/mlxrunner/cache/cache.go (new file, vendored, 196 lines)

@@ -0,0 +1,196 @@
package cache

import (
	"log/slog"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

type Cache interface {
	Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
	State() (keys, values *mlx.Array)
	Trim(int) int
	Clone() Cache
	Offset() int
	Len() int
}

type KVCache struct {
	keys, values *mlx.Array
	offset       int
	step         int
}

func NewKVCache() *KVCache {
	return &KVCache{step: 256, keys: &mlx.Array{}, values: &mlx.Array{}}
}

func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if needed
	if !c.keys.Valid() || (prev+L) > c.keys.Dim(2) {
		steps := (c.step + L - 1) / c.step
		newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)

		if c.keys.Valid() {
			if prev%c.step != 0 {
				c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
				c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
			}
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
	}

	c.offset += L
	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))

	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
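The growth branch above allocates in whole multiples of step (256): `(c.step + L - 1) / c.step` is the usual ceiling-division idiom for ceil(L/step). A tiny worked example:

package main

import "fmt"

func main() {
	step, L := 256, 300
	steps := (step + L - 1) / step // ceil(300/256) = 2
	fmt.Println(steps * step)      // 512 new positions allocated
}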

func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset == c.keys.Dim(2) {
		return c.keys, c.values
	}
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}

func (c *KVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	return n
}

func (c *KVCache) Clone() Cache {
	return &KVCache{
		keys:   c.keys.Clone(),
		values: c.values.Clone(),
		offset: c.offset,
		step:   c.step,
	}
}

func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int    { return c.offset }

// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
	maxSize int
	idx     int

	*KVCache
}

func NewRotatingKVCache(maxSize int) *RotatingKVCache {
	return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}

func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	if keys.Dim(2) > 1 {
		return c.concat(keys, values)
	}
	return c.update(keys, values)
}

func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
	slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	if !c.keys.Valid() {
		c.keys, c.values = keys, values
	} else {
		if c.idx < c.keys.Dim(2) {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
		}

		// Trim to max_size to maintain sliding window
		if trim := c.idx - c.maxSize + 1; trim > 0 {
			c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
			c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		}

		c.keys.Set(c.keys.Concatenate(2, keys))
		c.values.Set(c.values.Concatenate(2, values))
		c.idx = c.keys.Dim(2)
	}

	c.offset += keys.Dim(2)
	c.idx = c.keys.Dim(2)
	return c.keys, c.values
}

func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if not yet at max
	if !c.keys.Valid() || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
		newSize := min(c.step, c.maxSize-prev)
		newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
		if c.keys.Valid() {
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
		c.idx = prev
	}

	// Trim to max_size to maintain sliding window
	if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
		c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
		c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		c.idx = c.maxSize
	}

	// Rotate when hitting max
	if c.idx >= c.maxSize {
		c.idx = 0
	}

	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))

	c.offset += L
	c.idx += L

	validLen := min(c.offset, c.maxSize)
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}

func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset < c.keys.Dim(2) {
		return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
			c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
	}
	return c.keys, c.values
}

func (c *RotatingKVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	c.idx -= n
	return n
}

func (c *RotatingKVCache) Clone() Cache {
	return &RotatingKVCache{
		maxSize: c.maxSize,
		idx:     c.idx,
		KVCache: c.KVCache.Clone().(*KVCache),
	}
}

func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
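For intuition about the single-token update path: once idx reaches maxSize it wraps to 0, so the buffer behaves as a ring over the last maxSize positions while offset keeps the global count. A standalone sketch with maxSize = 4:

package main

import "fmt"

func main() {
	maxSize, idx := 4, 0
	for step := 0; step < 6; step++ {
		if idx >= maxSize {
			idx = 0 // rotate: overwrite the oldest slot
		}
		fmt.Printf("token %d -> slot %d\n", step, idx)
		idx++
	}
	// slots: 0 1 2 3 0 1
}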
x/mlxrunner/client.go (new file, 174 lines)

@@ -0,0 +1,174 @@
package mlxrunner

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"math"
	"net"
	"net/http"
	"net/url"
	"os/exec"
	"strconv"
	"strings"

	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/ml"
)

type Client struct {
	Port int
	*exec.Cmd
}

func (c *Client) JoinPath(path string) string {
	return (&url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
	}).JoinPath(path).String()
}

func (c *Client) CheckError(w *http.Response) error {
	if w.StatusCode >= 400 {
		return errors.New(w.Status)
	}
	return nil
}

// Close implements llm.LlamaServer.
func (c *Client) Close() error {
	return c.Cmd.Process.Kill()
}

// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(req); err != nil {
		return err
	}

	w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
	if err != nil {
		return err
	}
	defer w.Body.Close()

	if err := c.CheckError(w); err != nil {
		return err
	}

	scanner := bufio.NewScanner(w.Body)
	for scanner.Scan() {
		bts := scanner.Bytes()

		var resp llm.CompletionResponse
		if err := json.Unmarshal(bts, &resp); err != nil {
			return err
		}

		fn(resp)
	}

	return nil
}

func (c *Client) ContextLength() int {
	return math.MaxInt
}

// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
	panic("unimplemented")
}

// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	panic("unimplemented")
}

// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	panic("unimplemented")
}

// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
	return c.Port
}

// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
	panic("unimplemented")
}

// Load implements llm.LlamaServer.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
	w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()

	return []ml.DeviceID{}, nil
}

// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
	panic("unimplemented")
}

// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
	panic("unimplemented")
}

// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
	w, err := http.Get(c.JoinPath("/v1/status"))
	if err != nil {
		return err
	}
	defer w.Body.Close()

	return nil
}

// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
	w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
	if err != nil {
		return nil, err
	}
	defer w.Body.Close()

	var tokens []int
	if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
		return nil, err
	}

	return tokens, nil
}

// TotalSize implements llm.LlamaServer.
func (c *Client) TotalSize() uint64 {
	panic("unimplemented")
}

// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
	panic("unimplemented")
}

// VRAMSize implements llm.LlamaServer.
func (c *Client) VRAMSize() uint64 {
	panic("unimplemented")
}

// WaitUntilRunning implements llm.LlamaServer.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
	panic("unimplemented")
}

var _ llm.LlamaServer = (*Client)(nil)
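Client satisfies llm.LlamaServer by proxying everything over the runner's local HTTP API. A hedged usage sketch — the port is whatever subproc picked, and 51234 is an assumption:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/x/mlxrunner"
)

func main() {
	c := &mlxrunner.Client{Port: 51234} // Cmd omitted: not needed for HTTP calls
	if err := c.Ping(context.Background()); err != nil {
		log.Fatal(err)
	}
	tokens, err := c.Tokenize(context.Background(), "hello world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(tokens)
}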
x/mlxrunner/mlx/.gitignore (new file, vendored, 3 lines)

@@ -0,0 +1,3 @@
_deps
build
dist
x/mlxrunner/mlx/CMakeLists.txt (new file, 26 lines)

@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.5)

project(mlx)

if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()

set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)

set(CMAKE_INSTALL_RPATH "@loader_path")

include(FetchContent)

set(MLX_C_GIT_TAG "v0.4.0" CACHE STRING "")

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG}
)

FetchContent_MakeAvailable(mlx-c)
x/mlxrunner/mlx/act.go (new file, 21 lines)

@@ -0,0 +1,21 @@
package mlx

// #include "generated.h"
import "C"
import "math"

func GELUApprox(t *Array) *Array {
	return t.Multiply(
		FromValue[float32](0.5),
	).Multiply(
		t.Add(
			t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
		).Multiply(
			FromValue(float32(math.Sqrt(2 / math.Pi))),
		).Tanh().Add(FromValue[float32](1.0)),
	).AsType(t.DType())
}

func SILU(t *Array) *Array {
	return t.Multiply(t.Sigmoid()).AsType(t.DType())
}
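For reference, GELUApprox is the standard tanh approximation of GELU and SILU is the sigmoid-weighted linear unit:

\mathrm{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right), \qquad \mathrm{SiLU}(x) = x\,\sigma(x)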
x/mlxrunner/mlx/array.go (new file, 264 lines)

@@ -0,0 +1,264 @@
package mlx

// #include "generated.h"
import "C"

import (
	"encoding/binary"
	"log/slog"
	"reflect"
	"strings"
	"time"
	"unsafe"

	"github.com/ollama/ollama/logutil"
)

type tensorDesc struct {
	name    string
	inputs  []*Array
	numRefs int
}

func (d tensorDesc) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", d.name),
		slog.Int("inputs", len(d.inputs)),
		slog.Int("num_refs", d.numRefs),
	)
}

type Array struct {
	ctx  C.mlx_array
	desc tensorDesc
}

// constructor utilities

func New(name string, inputs ...*Array) *Array {
	t := &Array{
		desc: tensorDesc{
			name:   name,
			inputs: inputs,
		},
	}

	for _, input := range inputs {
		input.desc.numRefs++
	}
	logutil.Trace("New", "t", t)
	return t
}

type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}

func FromValue[T scalarTypes](t T) *Array {
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		panic("unsupported type")
	}
	return tt
}

type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}

func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	if len(shape) == 0 {
		panic("shape must be provided for non-scalar tensors")
	}

	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}

	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("unsupported type")
	}

	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}

	tt := New("")
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}

func (t *Array) Set(other *Array) {
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}

func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}

// misc. utilities

func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}

func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}

func (t *Array) LogValue() slog.Value {
	attrs := []slog.Attr{slog.Any("", t.desc)}
	if t.Valid() {
		attrs = append(attrs,
			slog.Any("dtype", t.DType()),
			slog.Any("shape", t.Dims()),
			slog.Int("num_bytes", t.NumBytes()),
		)
	}
	return slog.GroupValue(attrs...)
}

// shape utilities

func (t Array) Size() int {
	return int(C.mlx_array_size(t.ctx))
}

func (t Array) NumBytes() int {
	return int(C.mlx_array_nbytes(t.ctx))
}

func (t Array) NumDims() int {
	return int(C.mlx_array_ndim(t.ctx))
}

func (t Array) Dims() []int {
	dims := make([]int, t.NumDims())
	for i := range dims {
		dims[i] = t.Dim(i)
	}

	return dims
}

func (t Array) Dim(dim int) int {
	return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}

func (t Array) DType() DType {
	return DType(C.mlx_array_dtype(t.ctx))
}

// data utilities

func (t Array) Int() int {
	var item C.int64_t
	C.mlx_array_item_int64(&item, t.ctx)
	return int(item)
}

func (t Array) Float() float64 {
	var item C.double
	C.mlx_array_item_float64(&item, t.ctx)
	return float64(item)
}

func (t Array) Ints() []int {
	ints := make([]int, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
		ints[i] = int(f)
	}
	return ints
}

func (t Array) Floats() []float32 {
	floats := make([]float32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
		floats[i] = float32(f)
	}
	return floats
}

func Free(s ...*Array) (n int) {
	now := time.Now()
	defer func() {
		if n > 0 {
			logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
		}
	}()

	free := make([]*Array, 0, 8192)
	fn := func(t *Array) {
		if t.Valid() {
			free = append(free, t.desc.inputs...)
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				logutil.Trace("Free", "t", t)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				t.ctx.ctx = nil
			}
		}
	}

	for _, t := range s {
		fn(t)
	}

	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}

	return n
}
|
||||
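Taken together, the constructors, evaluation, and the reference-counted Free above form a build-evaluate-release cycle. A minimal usage sketch under those assumptions (hypothetical values; FromValues, Add, and Eval are defined elsewhere in this diff):

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

func main() {
	a := mlx.FromValues([]float32{1, 2, 3}, 3)
	b := mlx.FromValues([]float32{4, 5, 6}, 3)

	// Add records a and b as inputs of sum, so freeing sum can
	// propagate reference-count decrements to them.
	sum := a.Add(b)
	mlx.Eval(sum)
	fmt.Println(sum.Floats()) // [5 7 9]

	// Free releases sum and then walks its inputs, freeing any
	// array whose reference count drops to zero.
	mlx.Free(sum)
}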
43  x/mlxrunner/mlx/array_test.go  Normal file
@@ -0,0 +1,43 @@
package mlx

import "testing"

func TestFromValue(t *testing.T) {
	for got, want := range map[*Array]DType{
		FromValue(true):              DTypeBool,
		FromValue(false):             DTypeBool,
		FromValue(int(7)):            DTypeInt32,
		FromValue(float32(3.14)):     DTypeFloat32,
		FromValue(float64(2.71)):     DTypeFloat64,
		FromValue(complex64(1 + 2i)): DTypeComplex64,
	} {
		t.Run(want.String(), func(t *testing.T) {
			if got.DType() != want {
				t.Errorf("want %v, got %v", want, got.DType())
			}
		})
	}
}

func TestFromValues(t *testing.T) {
	for got, want := range map[*Array]DType{
		FromValues([]bool{true, false, true}, 3):           DTypeBool,
		FromValues([]uint8{1, 2, 3}, 3):                    DTypeUint8,
		FromValues([]uint16{1, 2, 3}, 3):                   DTypeUint16,
		FromValues([]uint32{1, 2, 3}, 3):                   DTypeUint32,
		FromValues([]uint64{1, 2, 3}, 3):                   DTypeUint64,
		FromValues([]int8{-1, -2, -3}, 3):                  DTypeInt8,
		FromValues([]int16{-1, -2, -3}, 3):                 DTypeInt16,
		FromValues([]int32{-1, -2, -3}, 3):                 DTypeInt32,
		FromValues([]int64{-1, -2, -3}, 3):                 DTypeInt64,
		FromValues([]float32{3.14, 2.71, 1.61}, 3):         DTypeFloat32,
		FromValues([]float64{3.14, 2.71, 1.61}, 3):         DTypeFloat64,
		FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
	} {
		t.Run(want.String(), func(t *testing.T) {
			if got.DType() != want {
				t.Errorf("want %v, got %v", want, got.DType())
			}
		})
	}
}
94  x/mlxrunner/mlx/dtype.go  Normal file
@@ -0,0 +1,94 @@
package mlx

// #include "generated.h"
import "C"

type DType int

func (t DType) String() string {
	switch t {
	case DTypeBool:
		return "BOOL"
	case DTypeUint8:
		return "U8"
	case DTypeUint16:
		return "U16"
	case DTypeUint32:
		return "U32"
	case DTypeUint64:
		return "U64"
	case DTypeInt8:
		return "I8"
	case DTypeInt16:
		return "I16"
	case DTypeInt32:
		return "I32"
	case DTypeInt64:
		return "I64"
	case DTypeFloat16:
		return "F16"
	case DTypeFloat32:
		return "F32"
	case DTypeFloat64:
		return "F64"
	case DTypeBFloat16:
		return "BF16"
	case DTypeComplex64:
		return "C64"
	default:
		return "Unknown"
	}
}

func (t *DType) UnmarshalJSON(b []byte) error {
	switch string(b) {
	case `"BOOL"`:
		*t = DTypeBool
	case `"U8"`:
		*t = DTypeUint8
	case `"U16"`:
		*t = DTypeUint16
	case `"U32"`:
		*t = DTypeUint32
	case `"U64"`:
		*t = DTypeUint64
	case `"I8"`:
		*t = DTypeInt8
	case `"I16"`:
		*t = DTypeInt16
	case `"I32"`:
		*t = DTypeInt32
	case `"I64"`:
		*t = DTypeInt64
	case `"F16"`:
		*t = DTypeFloat16
	case `"F64"`:
		*t = DTypeFloat64
	case `"F32"`:
		*t = DTypeFloat32
	case `"BF16"`:
		*t = DTypeBFloat16
	case `"C64"`:
		*t = DTypeComplex64
	default:
		return nil
	}
	return nil
}

const (
	DTypeBool      DType = C.MLX_BOOL
	DTypeUint8     DType = C.MLX_UINT8
	DTypeUint16    DType = C.MLX_UINT16
	DTypeUint32    DType = C.MLX_UINT32
	DTypeUint64    DType = C.MLX_UINT64
	DTypeInt8      DType = C.MLX_INT8
	DTypeInt16     DType = C.MLX_INT16
	DTypeInt32     DType = C.MLX_INT32
	DTypeInt64     DType = C.MLX_INT64
	DTypeFloat16   DType = C.MLX_FLOAT16
	DTypeFloat32   DType = C.MLX_FLOAT32
	DTypeFloat64   DType = C.MLX_FLOAT64
	DTypeBFloat16  DType = C.MLX_BFLOAT16
	DTypeComplex64 DType = C.MLX_COMPLEX64
)
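Since UnmarshalJSON accepts the same short codes that String produces, dtypes round-trip cleanly through JSON config files. A small sketch:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

func main() {
	var dt mlx.DType
	// the quoted short code matches a case in DType.UnmarshalJSON
	if err := json.Unmarshal([]byte(`"BF16"`), &dt); err != nil {
		panic(err)
	}
	fmt.Println(dt) // BF16
}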
34  x/mlxrunner/mlx/dynamic.c  Normal file
@@ -0,0 +1,34 @@
#include "dynamic.h"

#include <stdio.h>

#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif

static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
  handle->ctx = (void*) DLOPEN(path);
  CHECK(handle->ctx != NULL);
  return 0;
}

int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
  return mlx_dynamic_open(handle, path);
}

void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
  if (handle->ctx) {
    DLCLOSE(handle->ctx);
    handle->ctx = NULL;
  }
}
63  x/mlxrunner/mlx/dynamic.go  Normal file
@@ -0,0 +1,63 @@
package mlx

// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"

import (
	"io/fs"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"unsafe"
)

func init() {
	switch runtime.GOOS {
	case "darwin":
	case "windows":
	default:
		return
	}

	paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
	if !ok {
		slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
		return
	}

	for _, path := range filepath.SplitList(paths) {
		matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
		if err != nil {
			panic(err)
		}

		for _, match := range matches {
			// matches are relative to this list entry, not the whole list
			path := filepath.Join(path, match)
			slog.Info("Loading MLX dynamic library", "path", path)

			cPath := C.CString(path)
			defer C.free(unsafe.Pointer(cPath))

			var handle C.mlx_dynamic_handle
			if C.mlx_dynamic_load(&handle, cPath) != 0 {
				slog.Error("Failed to load MLX dynamic library", "path", path)
				continue
			}

			if C.mlx_dynamic_load_symbols(handle) != 0 {
				slog.Error("Failed to load MLX dynamic library symbols", "path", path)
				C.mlx_dynamic_unload(&handle)
				continue
			}

			slog.Info("Loaded MLX dynamic library", "path", path)
			return
		}
	}

	panic("Failed to load any MLX dynamic library")
}
27  x/mlxrunner/mlx/dynamic.h  Normal file
@@ -0,0 +1,27 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H

#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)((handle).ctx), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym((handle).ctx, symbol)
#endif

#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)

typedef struct {
  void* ctx;
} mlx_dynamic_handle;

int mlx_dynamic_load(
    mlx_dynamic_handle* handle,
    const char *path);

void mlx_dynamic_unload(
    mlx_dynamic_handle* handle);

#endif // MLX_DYNAMIC_H
72  x/mlxrunner/mlx/fast.go  Normal file
@@ -0,0 +1,72 @@
package mlx

// #include "generated.h"
// #include <stdlib.h>
import "C"

import (
	"unsafe"
)

func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
	if mask == nil {
		mask = New("")
	}

	sinks := New("")

	mode := "causal"
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))

	out := New("FAST_SDPA", query, key, value, mask, sinks)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}

type LayerNorm struct {
	Weight Array `weight:"weight"`
	Bias   Array `weight:"bias"`
}

func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

type RMSNorm struct {
	Weight Array `weight:"weight"`
}

func (r RMSNorm) Forward(x *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

type RoPE struct {
	Dims        int
	Traditional bool
	Base        float32 `json:"rope_theta"`
	Scale       float32
}

func (r RoPE) Forward(t *Array, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", t, freqs)
	C.mlx_fast_rope(
		&out.ctx,
		t.ctx,
		C.int(r.Dims),
		C._Bool(r.Traditional),
		C.mlx_optional_float{
			value:     C.float(r.Base),
			has_value: C._Bool(r.Base != 0),
		},
		C.float(r.Scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}
6568  x/mlxrunner/mlx/generated.c  Normal file
File diff suppressed because it is too large.
4872  x/mlxrunner/mlx/generated.h  Normal file
File diff suppressed because it is too large.
24  x/mlxrunner/mlx/generator/generated.c.gotmpl  Normal file
@@ -0,0 +1,24 @@
// This code is auto-generated; DO NOT EDIT.

#include "generated.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}

int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
  CHECK_LOAD(handle, {{ .Name }});
{{- end }}
  return 0;
}

{{- range .Functions }}

{{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
  return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}
20  x/mlxrunner/mlx/generator/generated.h.gotmpl  Normal file
@@ -0,0 +1,20 @@
// This code is auto-generated; DO NOT EDIT.

#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H

#include "dynamic.h"
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}

int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
{{ .Type }} {{ .Name }}{{ .Parameters }};
{{- end }}

#endif // MLX_GENERATED_H
135  x/mlxrunner/mlx/generator/main.go  Normal file
@@ -0,0 +1,135 @@
package main

import (
	"embed"
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"text/template"

	tree_sitter "github.com/tree-sitter/go-tree-sitter"
	tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)

//go:embed *.gotmpl
var fsys embed.FS

type Function struct {
	Type,
	Name,
	Parameters,
	Args string
}

func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
	var fn Function
	fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
	if params := node.ChildByFieldName("parameters"); params != nil {
		fn.Parameters = params.Utf8Text(source)
		fn.Args = ParseParameters(params, tc, source)
	}

	var types []string
	for node.Parent() != nil && node.Parent().Kind() != "declaration" {
		if node.Parent().Kind() == "pointer_declarator" {
			types = append(types, "*")
		}
		node = node.Parent()
	}

	for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
		types = append(types, sibling.Utf8Text(source))
	}

	slices.Reverse(types)
	fn.Type = strings.Join(types, " ")
	return fn
}

func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
	var s []string
	for _, child := range node.Children(tc) {
		if child.IsNamed() {
			child := child.ChildByFieldName("declarator")
			for child != nil && child.Kind() != "identifier" {
				if child.Kind() == "parenthesized_declarator" {
					child = child.Child(1)
				} else {
					child = child.ChildByFieldName("declarator")
				}
			}

			if child != nil {
				s = append(s, child.Utf8Text(source))
			}
		}
	}
	return strings.Join(s, ", ")
}

func main() {
	var output string
	flag.StringVar(&output, "output", ".", "Output directory for generated files")
	flag.Parse()

	parser := tree_sitter.NewParser()
	defer parser.Close()

	language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
	parser.SetLanguage(language)

	query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
	defer query.Close()

	qc := tree_sitter.NewQueryCursor()
	defer qc.Close()

	var funs []Function
	for _, arg := range flag.Args() {
		bts, err := os.ReadFile(arg)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
			continue
		}

		tree := parser.Parse(bts, nil)
		defer tree.Close()

		tc := tree.Walk()
		defer tc.Close()

		matches := qc.Matches(query, tree.RootNode(), bts)
		for match := matches.Next(); match != nil; match = matches.Next() {
			for _, capture := range match.Captures {
				funs = append(funs, ParseFunction(&capture.Node, tc, bts))
			}
		}
	}

	tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
		return
	}

	for _, tmpl := range tmpl.Templates() {
		name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))

		fmt.Println("Generating", name)
		f, err := os.Create(name)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
			continue
		}
		defer f.Close()

		if err := tmpl.Execute(f, map[string]any{
			"Functions": funs,
		}); err != nil {
			fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
		}
	}
}
91  x/mlxrunner/mlx/io.go  Normal file
@@ -0,0 +1,91 @@
package mlx

// #include "generated.h"
// #include <stdlib.h>
import "C"

import (
	"iter"
	"log/slog"
	"maps"
	"slices"
	"unsafe"

	"github.com/ollama/ollama/types/model"
)

func Load(path string) iter.Seq2[string, *Array] {
	return func(yield func(string, *Array) bool) {
		string2array := C.mlx_map_string_to_array_new()
		defer C.mlx_map_string_to_array_free(string2array)

		string2string := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_string_free(string2string)

		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath))

		cpu := C.mlx_default_cpu_stream_new()
		defer C.mlx_stream_free(cpu)

		C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)

		it := C.mlx_map_string_to_array_iterator_new(string2array)
		defer C.mlx_map_string_to_array_iterator_free(it)

		for {
			var key *C.char
			value := C.mlx_array_new()
			if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
				break
			}

			name := C.GoString(key)
			if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
				break
			}
		}
	}
}

func LoadAll(root *model.Root, pattern string, states map[string]*Array, afterLoadFuncs []func(*model.Root) error) error {
	matches, err := root.Glob(pattern)
	if err != nil {
		return err
	}

	weights := make(map[string]*Array)
	for match := range matches {
		slog.Debug("Loading weights from", "file", match)
		maps.Copy(weights, maps.Collect(Load(root.JoinPath("blobs", match))))
	}

	var numBytes int
	for name, weight := range states {
		if _, ok := weights[name]; ok {
			slog.Debug("Loading weight", "name", name, "weight", weight)
			*weight = *weights[name]
			numBytes += weight.NumBytes()
		}
	}

	for _, afterLoadFunc := range afterLoadFuncs {
		if err := afterLoadFunc(root); err != nil {
			return err
		}
	}

	Eval(slices.Collect(maps.Values(states))...)
	ClearCache()
	slog.Info("Loaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
	return nil
}

func UnloadAll(states map[string]*Array) {
	weights := slices.Collect(maps.Values(states))
	for _, weight := range weights {
		weight.desc.numRefs = 0
	}

	numBytes := Free(weights...)
	slog.Info("Unloaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
}
85  x/mlxrunner/mlx/memory.go  Normal file
@@ -0,0 +1,85 @@
package mlx

// #include "generated.h"
import "C"

import (
	"fmt"
	"log/slog"
	"strconv"
)

func (b Byte) String() string {
	return strconv.FormatInt(int64(b), 10) + " B"
}

func (b KibiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}

func (b MebiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}

func (b GibiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}

func (b TebiByte) String() string {
	return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}

func PrettyBytes(n int) fmt.Stringer {
	switch {
	case n < 1<<10:
		return Byte(n)
	case n < 1<<(2*10):
		return KibiByte(n)
	case n < 1<<(3*10):
		return MebiByte(n)
	case n < 1<<(4*10):
		return GibiByte(n)
	default:
		return TebiByte(n)
	}
}

func ActiveMemory() int {
	var active C.size_t
	C.mlx_get_active_memory(&active)
	return int(active)
}

func CacheMemory() int {
	var cache C.size_t
	C.mlx_get_cache_memory(&cache)
	return int(cache)
}

func PeakMemory() int {
	var peak C.size_t
	C.mlx_get_peak_memory(&peak)
	return int(peak)
}

type Memory struct{}

func (Memory) LogValue() slog.Value {
	return slog.GroupValue(
		slog.Any("active", PrettyBytes(ActiveMemory())),
		slog.Any("cache", PrettyBytes(CacheMemory())),
		slog.Any("peak", PrettyBytes(PeakMemory())),
	)
}

type (
	Byte     int
	KibiByte int
	MebiByte int
	GibiByte int
	TebiByte int
)

func ClearCache() {
	C.mlx_clear_cache()
}
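PrettyBytes selects the largest binary unit that keeps the value under 1024, so memory log lines stay readable at any scale. For example:

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

func main() {
	fmt.Println(mlx.PrettyBytes(512))     // 512 B
	fmt.Println(mlx.PrettyBytes(1536))    // 1.50 KiB
	fmt.Println(mlx.PrettyBytes(3 << 30)) // 3.00 GiB
}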
43  x/mlxrunner/mlx/mlx.go  Normal file
@@ -0,0 +1,43 @@
package mlx

//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"

// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
import "C"

import (
	"unsafe"
)

func doEval(outputs []*Array, async bool) {
	vectorData := make([]C.mlx_array, 0, len(outputs))
	for _, output := range outputs {
		if output.Valid() {
			vectorData = append(vectorData, output.ctx)
		}
	}

	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)

	if async {
		C.mlx_async_eval(vector)
	} else {
		C.mlx_eval(vector)
	}
}

func AsyncEval(outputs ...*Array) {
	doEval(outputs, true)
}

func Eval(outputs ...*Array) {
	doEval(outputs, false)
}
30  x/mlxrunner/mlx/nn.go  Normal file
@@ -0,0 +1,30 @@
package mlx

type Linear struct {
	Weight Array `weight:"weight"`
	Bias   Array `weight:"bias"`
}

// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
	w := m.Weight.Transpose(1, 0)
	if m.Bias.Valid() {
		return m.Bias.Addmm(x, w, 1.0, 1.0)
	}

	return x.Matmul(w)
}

type Embedding struct {
	Weight Array `weight:"weight"`
}

func (e *Embedding) Forward(indices *Array) *Array {
	return e.Weight.TakeAxis(indices, 0)
}

func (e *Embedding) AsLinear() Linear {
	return Linear{
		Weight: e.Weight,
	}
}
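AsLinear reuses the embedding matrix as an output projection, i.e. weight tying; the Gemma model later in this diff relies on it for Unembed. A minimal sketch:

package example

import "github.com/ollama/ollama/x/mlxrunner/mlx"

// unembed maps hidden states back to vocabulary logits using the
// same matrix that embedded the input tokens.
func unembed(embed *mlx.Embedding, hidden *mlx.Array) *mlx.Array {
	return embed.AsLinear().Forward(hidden)
}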
192  x/mlxrunner/mlx/ops.go  Normal file
@@ -0,0 +1,192 @@
package mlx

// #include "generated.h"
import "C"

import (
	"unsafe"
)

func (t *Array) Abs() *Array {
	out := New("ABS", t)
	C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Add(other *Array) *Array {
	out := New("ADD", t, other)
	C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
	out := New("ADDMM", t, a, b)
	C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
	return out
}

func (t *Array) Argmax(axis int, keepDims bool) *Array {
	out := New("ARGMAX", t)
	C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
	return out
}

func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
	out := New("ARGPARTITION", t)
	C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) AsType(dtype DType) *Array {
	out := New("AS_TYPE", t)
	C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}

func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}

	cStrides := make([]C.int64_t, len(strides))
	for i, s := range strides {
		cStrides[i] = C.int64_t(s)
	}

	out := New("AS_STRIDED", t)
	C.mlx_as_strided(
		&out.ctx, t.ctx,
		unsafe.SliceData(cShape), C.size_t(len(shape)),
		unsafe.SliceData(cStrides), C.size_t(len(strides)),
		C.size_t(offset),
		DefaultStream().ctx,
	)
	return out
}

func (t *Array) Concatenate(axis int, others ...*Array) *Array {
	vectorData := make([]C.mlx_array, len(others)+1)
	vectorData[0] = t.ctx
	for i := range others {
		vectorData[i+1] = others[i].ctx
	}

	vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
	defer C.mlx_vector_array_free(vector)

	out := New("CONCATENATE", t)
	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) ExpandDims(axis int) *Array {
	out := New("EXPAND_DIMS", t)
	C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Logsumexp(keepDims bool) *Array {
	out := New("LOGSUMEXP", t)
	C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
	return out
}

func (t *Array) Matmul(other *Array) *Array {
	out := New("MATMUL", t, other)
	C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Multiply(other *Array) *Array {
	out := New("MULTIPLY", t, other)
	C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Negative() *Array {
	out := New("NEGATIVE", t)
	C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Power(exponent *Array) *Array {
	out := New("POWER", t, exponent)
	C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
	out := New("PUT_ALONG_AXIS", t, indices, values)
	C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Reshape(axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
	}

	out := New("RESHAPE", t)
	C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

func (t *Array) Sigmoid() *Array {
	out := New("SIGMOID", t)
	C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Sqrt() *Array {
	out := New("SQRT", t)
	C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Squeeze(axis int) *Array {
	out := New("SQUEEZE", t)
	C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Subtract(other *Array) *Array {
	out := New("SUBTRACT", t, other)
	C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) TakeAxis(indices *Array, axis int) *Array {
	out := New("TAKE_AXIS", t, indices)
	C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

func (t *Array) Tanh() *Array {
	out := New("TANH", t)
	C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}

func (t *Array) Transpose(axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i, axis := range axes {
		cAxes[i] = C.int(axis)
	}

	out := New("TRANSPOSE", t)
	C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}

func Zeros(dtype DType, shape ...int) *Array {
	cAxes := make([]C.int, len(shape))
	for i := range shape {
		cAxes[i] = C.int(shape[i])
	}

	t := New("ZEROS")
	C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
	return t
}
11  x/mlxrunner/mlx/random.go  Normal file
@@ -0,0 +1,11 @@
package mlx

// #include "generated.h"
import "C"

func (t *Array) Categorical(axis int) *Array {
	key := New("")
	out := New("", t, key)
	C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
	return out
}
84  x/mlxrunner/mlx/slice.go  Normal file
@@ -0,0 +1,84 @@
package mlx

// #include "generated.h"
import "C"

import (
	"cmp"
	"unsafe"
)

type slice struct {
	args []int
}

func Slice(args ...int) slice {
	return slice{args: args}
}

func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
	if len(slices) != len(dims) {
		panic("number of slice arguments must match number of tensor dimensions")
	}

	args := [3][]C.int{
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
	}

	for i, s := range slices {
		switch len(s.args) {
		case 0:
			// slice[:]
			args[0][i] = C.int(0)
			args[1][i] = C.int(dims[i])
			args[2][i] = C.int(1)
		case 1:
			// slice[i]
			args[0][i] = C.int(s.args[0])
			args[1][i] = C.int(s.args[0] + 1)
			args[2][i] = C.int(1)
		case 2:
			// slice[i:j]
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(1)
		case 3:
			// slice[i:j:k]
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(s.args[2])
		default:
			panic("invalid slice arguments")
		}
	}

	return args[0], args[1], args[2]
}

func (t *Array) Slice(slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE", t)
	C.mlx_slice(
		&out.ctx, t.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}

func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE_UPDATE", t, other)
	C.mlx_slice_update(
		&out.ctx, t.ctx, other.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}
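The variadic Slice constructor mirrors Python's indexing forms, as the comments in makeSlices suggest. A short sketch, assuming t is a 4x8 array:

package example

import "github.com/ollama/ollama/x/mlxrunner/mlx"

func examples(t *mlx.Array) {
	_ = t.Slice(mlx.Slice(1), mlx.Slice())       // t[1, :]    -> shape [1, 8]
	_ = t.Slice(mlx.Slice(), mlx.Slice(2, 5))    // t[:, 2:5]  -> shape [4, 3]
	_ = t.Slice(mlx.Slice(), mlx.Slice(0, 0, 2)) // t[:, 0:8:2] -> shape [4, 4]; a stop of 0 means "to the end"
}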
43  x/mlxrunner/mlx/stream.go  Normal file
@@ -0,0 +1,43 @@
package mlx

// #include "generated.h"
import "C"

import (
	"log/slog"
	"sync"
)

type Device struct {
	ctx C.mlx_device
}

func (d Device) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_device_tostring(&str, d.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

var DefaultDevice = sync.OnceValue(func() Device {
	d := C.mlx_device_new()
	C.mlx_get_default_device(&d)
	return Device{d}
})

type Stream struct {
	ctx C.mlx_stream
}

func (s Stream) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_stream_tostring(&str, s.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

var DefaultStream = sync.OnceValue(func() Stream {
	s := C.mlx_stream_new()
	C.mlx_get_default_stream(&s, DefaultDevice().ctx)
	return Stream{s}
})
8  x/mlxrunner/model/base/cache.go  Normal file
@@ -0,0 +1,8 @@
package base

import "github.com/ollama/ollama/x/mlxrunner/cache"

// Cacher is implemented by models that support custom caching mechanisms.
type Cacher interface {
	Cache() []cache.Cache
}
113  x/mlxrunner/model/base/model.go  Normal file
@@ -0,0 +1,113 @@
package base

import (
	"encoding/json"
	"errors"
	"log/slog"
	"reflect"
	"strconv"
	"strings"

	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

type Model interface {
	// Forward performs a forward pass through the model.
	Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array

	// NumLayers returns the number of layers in the model.
	// This is used to initialize caches.
	// TODO: consider moving cache initialization into the model itself.
	NumLayers() int
}

type TextGeneration interface {
	Model
	Unembed(*mlx.Array) *mlx.Array
}

func Weights(m Model) (map[string]*mlx.Array, []func(*model.Root) error) {
	mapping := make(map[string]*mlx.Array)
	var afterLoadFuncs []func(*model.Root) error
	var fn func(v reflect.Value, tags []string)
	fn = func(v reflect.Value, tags []string) {
		t := v.Type()

		if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
			var afterLoadFunc func(*model.Root) error
			reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
			afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
		}

		for _, field := range reflect.VisibleFields(t) {
			if field.IsExported() {
				tt, vv := field.Type, v.FieldByIndex(field.Index)

				// create local copy so tags are not modified between fields
				tags := tags
				if tag := field.Tag.Get("weight"); tag != "" {
					// TODO: use model.Tag
					tags = append(tags, tag)
				}

				if tt == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
					name := strings.Join(tags, ".")
					mapping[name] = vv.Addr().Interface().(*mlx.Array)
					continue
				}

				switch tt.Kind() {
				case reflect.Interface:
					vv = vv.Elem()
					fallthrough
				case reflect.Pointer:
					vv = vv.Elem()
					fallthrough
				case reflect.Struct:
					fn(vv, tags)
				case reflect.Slice, reflect.Array:
					for i := range vv.Len() {
						fn(vv.Index(i), append(tags, strconv.Itoa(i)))
					}
				}
			}
		}
	}
	fn(reflect.ValueOf(m).Elem(), []string{})
	return mapping, afterLoadFuncs
}

var m = make(map[string]func(*model.Root) (Model, error))

func Register(name string, f func(*model.Root) (Model, error)) {
	if _, exists := m[name]; exists {
		panic("model already registered: " + name)
	}

	m[name] = f
}

func New(root *model.Root) (Model, error) {
	c, err := root.Open("config.json")
	if err != nil {
		return nil, err
	}
	defer c.Close()

	var config struct {
		Architectures []string `json:"architectures"`
	}

	if err := json.NewDecoder(c).Decode(&config); err != nil {
		return nil, err
	}

	slog.Info("Model architecture", "arch", config.Architectures[0])
	if f, exists := m[config.Architectures[0]]; exists {
		return f(root)
	}

	return nil, errors.New("unknown architecture")
}
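Weights flattens the struct tree into dotted safetensors keys: `weight` tags and slice indices are joined with periods, so a layer's query projection resolves to a key like model.layers.0.self_attn.q_proj.weight. A reduced sketch of the mechanism (the tiny model below is hypothetical, not part of this diff):

package example

import (
	"fmt"

	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type tinyLayer struct {
	Attn mlx.Linear `weight:"self_attn"`
}

type tiny struct {
	Layers []tinyLayer `weight:"model.layers"`
}

func (*tiny) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
func (t *tiny) NumLayers() int                             { return len(t.Layers) }

func keys() {
	weights, _ := base.Weights(&tiny{Layers: make([]tinyLayer, 2)})
	for name := range weights {
		// prints keys such as model.layers.0.self_attn.weight
		// and model.layers.1.self_attn.bias
		fmt.Println(name)
	}
}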
84  x/mlxrunner/model/gemma/3/model.go  Normal file
@@ -0,0 +1,84 @@
package gemma

import (
	"cmp"
	"encoding/json"

	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type Model struct {
	Text TextModel `weight:"language_model"`
}

func (m *Model) NumLayers() int {
	return len(m.Text.Layers)
}

func (m Model) Cache() []cache.Cache {
	caches := make([]cache.Cache, m.NumLayers())
	for i := range caches {
		if (i+1)%m.Text.Options.SlidingWindowPattern == 0 {
			caches[i] = cache.NewKVCache()
		} else {
			caches[i] = cache.NewRotatingKVCache(m.Text.Options.SlidingWindow)
		}
	}
	return caches
}

func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
	return m.Text.Forward(inputs, cache)
}

func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.Text.EmbedTokens.AsLinear().Forward(x)
}

func init() {
	base.Register("Gemma3ForConditionalGeneration", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}

		var opts struct {
			Text TextOptions `json:"text_config"`
		}

		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}

		opts.Text.NumAttentionHeads = cmp.Or(opts.Text.NumAttentionHeads, 8)
		opts.Text.NumKeyValueHeads = cmp.Or(opts.Text.NumKeyValueHeads, 4)
		opts.Text.HeadDim = cmp.Or(opts.Text.HeadDim, 256)
		opts.Text.RMSNormEps = cmp.Or(opts.Text.RMSNormEps, 1e-6)
		opts.Text.SlidingWindowPattern = cmp.Or(opts.Text.SlidingWindowPattern, 6)

		// TODO: implement json.Unmarshaler
		opts.Text.RoPE = map[bool]mlx.RoPE{
			true:  {Dims: opts.Text.HeadDim, Traditional: false, Base: 1_000_000, Scale: 1. / 8.},
			false: {Dims: opts.Text.HeadDim, Traditional: false, Base: 10_000, Scale: 1},
		}

		return &Model{
			Text: TextModel{
				Layers:  make([]TextDecoderLayer, opts.Text.NumHiddenLayers),
				Options: opts.Text,
			},
		}, nil
	})
}

type RMSNorm struct {
	mlx.RMSNorm
}

func (m *RMSNorm) AfterLoad(*model.Root) error {
	m.Weight.Set(m.Weight.Add(mlx.FromValue(1)))
	return nil
}
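With the default SlidingWindowPattern of 6, Cache gives every sixth layer a global KV cache and all other layers the rotating sliding-window cache. A quick check of the indexing:

package main

import "fmt"

func main() {
	pattern := 6
	for i := 0; i < 12; i++ {
		// for a 12-layer model, layers 5 and 11 (0-indexed) are global
		fmt.Println(i, (i+1)%pattern == 0)
	}
}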
118  x/mlxrunner/model/gemma/3/text.go  Normal file
@@ -0,0 +1,118 @@
package gemma

import (
	"math"

	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

type TextOptions struct {
	HiddenSize           int     `json:"hidden_size"`
	NumHiddenLayers      int     `json:"num_hidden_layers"`
	IntermediateSize     int     `json:"intermediate_size"`
	NumAttentionHeads    int     `json:"num_attention_heads"`
	NumKeyValueHeads     int     `json:"num_key_value_heads"`
	HeadDim              int     `json:"head_dim"`
	RMSNormEps           float32 `json:"rms_norm_eps"`
	SlidingWindow        int     `json:"sliding_window"`
	SlidingWindowPattern int     `json:"sliding_window_pattern"`

	RoPE map[bool]mlx.RoPE
}

type TextModel struct {
	EmbedTokens mlx.Embedding      `weight:"model.embed_tokens"`
	Layers      []TextDecoderLayer `weight:"model.layers"`
	Norm        RMSNorm            `weight:"model.norm"`

	Options TextOptions
}

func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	hiddenStates := m.EmbedTokens.Forward(inputs)

	hiddenSize := mlx.FromValue(m.Options.HiddenSize).AsType(hiddenStates.DType())
	hiddenStates = hiddenStates.Multiply(hiddenSize.Sqrt())

	for i, layer := range m.Layers {
		hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options.RoPE[(i+1)%m.Options.SlidingWindowPattern == 0], m.Options)
	}

	hiddenStates = m.Norm.Forward(hiddenStates, m.Options.RMSNormEps)
	return hiddenStates
}

type TextDecoderLayer struct {
	InputNorm    RMSNorm       `weight:"input_layernorm"`
	Attention    TextAttention `weight:"self_attn"`
	PostAttnNorm RMSNorm       `weight:"post_attention_layernorm"`
	PreFFNorm    RMSNorm       `weight:"pre_feedforward_layernorm"`
	MLP          TextMLP       `weight:"mlp"`
	PostFFNorm   RMSNorm       `weight:"post_feedforward_layernorm"`
}

func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	residual := hiddenStates
	hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
	hiddenStates = m.PostAttnNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)

	residual = hiddenStates
	hiddenStates = m.PreFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.MLP.Forward(hiddenStates, opts)
	hiddenStates = m.PostFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)
	return hiddenStates
}

type TextAttention struct {
	QProj mlx.Linear `weight:"q_proj"`
	QNorm RMSNorm    `weight:"q_norm"`
	KProj mlx.Linear `weight:"k_proj"`
	KNorm RMSNorm    `weight:"k_norm"`
	VProj mlx.Linear `weight:"v_proj"`
	OProj mlx.Linear `weight:"o_proj"`
}

func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	query := m.QProj.Forward(hiddenStates)
	key := m.KProj.Forward(hiddenStates)
	value := m.VProj.Forward(hiddenStates)

	query = query.AsStrided(
		[]int{B, opts.NumAttentionHeads, L, opts.HeadDim},
		[]int{L * opts.NumAttentionHeads * opts.HeadDim, opts.HeadDim, opts.NumAttentionHeads * opts.HeadDim, 1},
		0)
	key = key.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)
	value = value.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)

	query = m.QNorm.Forward(query, opts.RMSNormEps)
	key = m.KNorm.Forward(key, opts.RMSNormEps)

	query = rope.Forward(query, cache.Offset())
	key = rope.Forward(key, cache.Offset())
	key, value = cache.Update(key, value)

	attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(opts.HeadDim))))
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}

type TextMLP struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
	return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}
130  x/mlxrunner/model/llama/model.go  Normal file
@@ -0,0 +1,130 @@
package llama

import (
	"encoding/json"
	"log/slog"
	"math"

	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

type Options struct {
	HiddenAct         string  `json:"hidden_act"`
	HiddenSize        int     `json:"hidden_size"`
	IntermediateSize  int     `json:"intermediate_size"`
	NumAttentionHeads int     `json:"num_attention_heads"`
	NumHiddenLayers   int     `json:"num_hidden_layers"`
	NumKeyValueHeads  int     `json:"num_key_value_heads"`
	RMSNormEps        float32 `json:"rms_norm_eps"`

	mlx.RoPE
}

type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	Output      mlx.Linear    `weight:"lm_head"`

	Options
}

func (m Model) NumLayers() int {
	return len(m.Layers)
}

func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
	B, L := inputs.Dim(0), inputs.Dim(1)
	hiddenStates := m.EmbedTokens.Forward(inputs)
	for i, layer := range m.Layers {
		hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options)
	}
	hiddenStates = m.Norm.Forward(hiddenStates, m.RMSNormEps)
	hiddenStates = m.Output.Forward(hiddenStates)
	slog.Debug("Model.forward", "output shape", hiddenStates.Dims(), "m.Output", m.Output.Weight.Dims())
	return hiddenStates
}

type Layer struct {
	AttentionNorm mlx.RMSNorm `weight:"input_layernorm"`
	Attention     Attention   `weight:"self_attn"`
	MLPNorm       mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP           MLP         `weight:"mlp"`
}

func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
	residual := hiddenStates
	hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
	hiddenStates = hiddenStates.Add(residual)

	residual = hiddenStates
	hiddenStates = m.MLPNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.MLP.Forward(hiddenStates)
	hiddenStates = hiddenStates.Add(residual)
	return hiddenStates
}

type Attention struct {
	QueryProj  mlx.Linear `weight:"q_proj"`
	KeyProj    mlx.Linear `weight:"k_proj"`
	ValueProj  mlx.Linear `weight:"v_proj"`
	OutputProj mlx.Linear `weight:"o_proj"`
}

func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	query := m.QueryProj.Forward(hiddenStates)
	query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)

	key := m.KeyProj.Forward(hiddenStates)
	key = key.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)

	value := m.ValueProj.Forward(hiddenStates)
	value = value.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)

	query = opts.RoPE.Forward(query, cache.Offset())
	key = opts.RoPE.Forward(key, cache.Offset())
	key, value = cache.Update(key, value)

	attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(key.Dim(-1)))))
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OutputProj.Forward(attention)
}

type MLP struct {
	Gate mlx.Linear `weight:"gate_proj"`
	Up   mlx.Linear `weight:"up_proj"`
	Down mlx.Linear `weight:"down_proj"`
}

func (m MLP) Forward(h *mlx.Array) *mlx.Array {
	return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
}

func init() {
	base.Register("MistralForCausalLM", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}

		var opts Options
		// TODO: implement json.Unmarshal for Options
		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}

		if err := json.Unmarshal(bts, &opts.RoPE); err != nil {
			return nil, err
		}

		return &Model{
			Layers:  make([]Layer, opts.NumHiddenLayers),
			Options: opts,
		}, nil
	})
}
6  x/mlxrunner/model/model.go  Normal file
@@ -0,0 +1,6 @@
package model

import (
	_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
	_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
)
138  x/mlxrunner/pipeline.go  Normal file
@@ -0,0 +1,138 @@
package mlxrunner

import (
	"bytes"
	"errors"
	"log/slog"
	"time"
	"unicode/utf8"

	"github.com/ollama/ollama/tokenizer"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
)

func (r *Runner) TextGenerationPipeline(request Request) error {
	model, ok := r.Model.(base.TextGeneration)
	if !ok {
		return errors.New("model does not support causal language modeling")
	}

	inputs, err := r.Tokenizer.Encode(request.Prompt, true)
	if err != nil {
		return err
	}

	caches, tokens := r.FindNearestCache(inputs)
	if len(caches) == 0 {
		if cacher, ok := model.(base.Cacher); ok {
			caches = cacher.Cache()
		} else {
			caches = make([]cache.Cache, model.NumLayers())
			for i := range caches {
				caches[i] = cache.NewKVCache()
			}
		}
	}

	total, processed := len(tokens), 0
	slog.Info("Prompt processing progress", "processed", processed, "total", total)
	for total-processed > 1 {
		n := min(2<<10, total-processed-1)
		temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		defer mlx.Free(temp)
		mlx.Eval(func() []*mlx.Array {
			s := make([]*mlx.Array, 2*len(caches))
			for i, c := range caches {
				s[2*i], s[2*i+1] = c.State()
			}
			return s
		}()...)
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)
		mlx.ClearCache()
	}

	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		// TODO: additional logit processing (logit bias, repetition penalty, etc.)

		logprobs := logits.Subtract(logits.Logsumexp(true))
		return request.Sample(logprobs), logprobs
	}

	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
	mlx.AsyncEval(sample, logprobs)

	// buffer partial, multibyte unicode
	var b bytes.Buffer

	now := time.Now()
	final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
	outputs := make([]int32, 0, request.Options.MaxTokens)
	for i := range request.Options.MaxTokens {
		nextSample, nextLogprobs := step(sample)
		mlx.AsyncEval(nextSample, nextLogprobs)

		if i == 0 {
			slog.Info("Prompt processing progress", "processed", total, "total", total)
			mlx.Eval(sample)
			final.PromptTokensDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Int())
		outputs = append(outputs, output)

		if r.Tokenizer.Is(output, tokenizer.SpecialEOS) {
			final.Token = int(output)
			final.DoneReason = 0
			final.CompletionTokens = i
			break
		}

		request.Responses <- Response{
			Text:  r.Decode(output, &b),
			Token: int(output),
		}

		mlx.Free(sample, logprobs)
		if i%256 == 0 {
			mlx.ClearCache()
		}

		sample, logprobs = nextSample, nextLogprobs
	}

	mlx.Free(sample, logprobs)
	final.CompletionTokensDuration = time.Since(now)
	request.Responses <- final
	r.InsertCache(append(inputs, outputs...), caches)
	return nil
}

func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
	token, err := r.Tokenizer.Decode([]int32{sample})
	if err != nil {
		slog.Error("Failed to decode tokens", "error", err)
		return ""
	}

	if _, err := b.WriteString(token); err != nil {
		slog.Error("Failed to write token to buffer", "error", err)
		return ""
	}

	if text := b.String(); utf8.ValidString(text) {
		b.Reset()
		return text
	} else if b.Len() >= utf8.UTFMax {
		b.Reset()
		return text
	}

	return ""
}
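Decode's buffer exists because a token boundary can fall inside a multi-byte UTF-8 sequence. The same flush rule in isolation (hypothetical byte pieces, not real tokenizer output):

package main

import (
	"bytes"
	"fmt"
	"unicode/utf8"
)

// flush emits buffered bytes once they form valid UTF-8, or once the
// buffer is long enough that it can never become a single valid rune.
func flush(b *bytes.Buffer, piece []byte) string {
	b.Write(piece)
	if text := b.String(); utf8.ValidString(text) || b.Len() >= utf8.UTFMax {
		b.Reset()
		return text
	}
	return ""
}

func main() {
	var b bytes.Buffer
	fmt.Printf("%q\n", flush(&b, []byte{0xE2, 0x82})) // "" (first two bytes of "€")
	fmt.Printf("%q\n", flush(&b, []byte{0xAC}))       // "€"
}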
110
x/mlxrunner/runner.go
Normal file
110
x/mlxrunner/runner.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package mlxrunner

import (
	"context"
	"log/slog"
	"net"
	"net/http"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/ollama/ollama/tokenizer"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	_ "github.com/ollama/ollama/x/mlxrunner/model"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
	"github.com/ollama/ollama/x/mlxrunner/sample"
)

type Request struct {
	TextCompletionsRequest
	Responses chan Response
	Pipeline  func(Request) error

	sample.Sampler
	caches []cache.Cache
}

type TextCompletionsRequest struct {
	Prompt  string `json:"prompt"`
	Options struct {
		Temperature float32 `json:"temperature"`
		TopP        float32 `json:"top_p"`
		MinP        float32 `json:"min_p"`
		TopK        int     `json:"top_k"`
		MaxTokens   int     `json:"max_tokens"`

		// Deprecated: use MaxTokens instead
		NumPredict int `json:"num_predict"`
	} `json:"options"`
}

type Response struct {
	Text       string    `json:"content,omitempty"`
	Token      int       `json:"token,omitempty"`
	Logprobs   []float32 `json:"logprobs,omitempty"`
	Done       bool      `json:"done,omitempty"`
	DoneReason int       `json:"done_reason,omitempty"`

	PromptTokens             int           `json:"prompt_eval_count,omitempty"`
	PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
	CompletionTokens         int           `json:"eval_count,omitempty"`
	CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
	TotalTokens              int           `json:"total_tokens,omitempty"`
}

type Runner struct {
	Model        base.Model
	Tokenizer    tokenizer.Tokenizer
	Requests     chan Request
	CacheEntries map[int32]*CacheEntry
}

func (r *Runner) Load(name model.Name) (err error) {
	root, err := model.Open(name)
	if err != nil {
		return err
	}
	defer root.Close()

	r.Model, err = base.New(root)
	if err != nil {
		return err
	}

	r.Tokenizer, err = tokenizer.New(root)
	if err != nil {
		return err
	}

	weights, afterLoadFuncs := base.Weights(r.Model)
	return mlx.LoadAll(root, "model*.safetensors", weights, afterLoadFuncs)
}

func (r *Runner) Run(host, port string, mux http.Handler) error {
	g, ctx := errgroup.WithContext(context.Background())

	g.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return nil
			case request := <-r.Requests:
				// Close Responses even when the pipeline fails;
				// otherwise the HTTP handler ranging over the
				// channel would block forever.
				if err := request.Pipeline(request); err != nil {
					slog.Error("Pipeline failed", "error", err)
				}

				close(request.Responses)
			}
		}
	})

	g.Go(func() error {
		slog.Info("Starting HTTP server", "host", host, "port", port)
		return http.ListenAndServe(net.JoinHostPort(host, port), mux)
	})

	return g.Wait()
}
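Run's dispatch loop pairs with the HTTP handlers in server.go below. A hypothetical caller-side sketch (not part of this diff; assumes a loaded runner in scope and a placeholder prompt) of how one request flows through it:

// Hypothetical usage: the pipeline streams tokens into Responses, and the
// channel close in Run's loop ends the range.
req := Request{Responses: make(chan Response)}
req.Prompt = "why is the sky blue?" // placeholder
req.Options.MaxTokens = 128
req.Pipeline = runner.TextGenerationPipeline
req.Sampler = sample.New(0.7, 0.9, 0, 40) // temp, top_p, min_p, top_k

runner.Requests <- req
for resp := range req.Responses {
	fmt.Print(resp.Text)
}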
75  x/mlxrunner/sample/sample.go  Normal file
@@ -0,0 +1,75 @@
package sample

import (
	"math"

	"github.com/ollama/ollama/x/mlxrunner/mlx"
)

type Sampler interface {
	Sample(*mlx.Array) *mlx.Array
}

// New builds a sampler chain from the request options. A temperature of zero
// short-circuits to greedy decoding.
func New(temp, topP, minP float32, topK int) Sampler {
	if temp == 0 {
		return greedy{}
	}

	var samplers []Sampler
	if topP > 0 && topP < 1 {
		samplers = append(samplers, TopP(topP))
	}

	if minP != 0 {
		samplers = append(samplers, MinP(minP))
	}

	if topK > 0 {
		samplers = append(samplers, TopK(topK))
	}

	samplers = append(samplers, Temperature(temp))
	return chain(samplers)
}

type greedy struct{}

func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return logits.Argmax(-1, false)
}

type chain []Sampler

func (c chain) Sample(logits *mlx.Array) *mlx.Array {
	for _, sampler := range c {
		logits = sampler.Sample(logits)
	}
	return logits
}

type Temperature float32

func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
	// Scale logits by 1/t, then draw a token from the distribution.
	return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
}

type TopP float32

func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}

type MinP float32

func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}

type TopK int

func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
	// Mask everything outside the k highest logprobs with -Inf.
	mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
	return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
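TopP and MinP are left as TODOs above. For reference, an illustrative plain-Go sketch of the min-p rule (keep tokens whose probability is at least minP times the maximum probability), written over a []float32 of normalized logprobs rather than the package's mlx.Array API:

package sample

import "math"

// minPMask is a sketch, not the mlx implementation the TODO calls for: it
// masks tokens with prob < minP*maxProb to -Inf. In log space that threshold
// is log(minP) + max(logprobs), assuming logprobs are normalized.
func minPMask(logprobs []float32, minP float32) []float32 {
	maxLP := float32(math.Inf(-1))
	for _, lp := range logprobs {
		if lp > maxLP {
			maxLP = lp
		}
	}

	threshold := float32(math.Log(float64(minP))) + maxLP
	out := make([]float32, len(logprobs))
	for i, lp := range logprobs {
		out[i] = lp
		if lp < threshold {
			out[i] = float32(math.Inf(-1))
		}
	}
	return out
}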
180  x/mlxrunner/server.go  Normal file
@@ -0,0 +1,180 @@
package mlxrunner

import (
	"bytes"
	"cmp"
	"encoding/json"
	"flag"
	"io"
	"log/slog"
	"net/http"
	"os"
	"strconv"
	"time"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/x/mlxrunner/sample"
)

func Execute(args []string) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

	var (
		name model.Name
		port int
	)

	flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
	flagSet.Var(&name, "model", "Model name")
	flagSet.IntVar(&port, "port", 0, "Port to listen on")
	_ = flagSet.Bool("verbose", false, "Enable debug logging")
	flagSet.Parse(args)

	runner := Runner{
		Requests:     make(chan Request),
		CacheEntries: make(map[int32]*CacheEntry),
	}

	if err := runner.Load(name); err != nil {
		return err
	}

	mux := http.NewServeMux()
	mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewEncoder(w).Encode(map[string]any{
			"status":   0,
			"progress": 100,
		}); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})

	mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case "POST", "GET":
			if err := json.NewEncoder(w).Encode(map[string]any{
				"Success": true,
			}); err != nil {
				slog.Error("Failed to encode response", "error", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
		case "DELETE":
			// TODO: cleanup model and cache
		}
	})

	mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
		request := Request{Responses: make(chan Response)}

		if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
			slog.Error("Failed to decode request", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}

		// Fall back to the deprecated num_predict, then to a 16Ki-token default.
		request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
		if request.Options.MaxTokens < 1 {
			request.Options.MaxTokens = 16 << 10
		}

		request.Pipeline = runner.TextGenerationPipeline
		request.Sampler = sample.New(
			request.Options.Temperature,
			request.Options.TopP,
			request.Options.MinP,
			request.Options.TopK,
		)

		runner.Requests <- request

		// Stream responses as JSON Lines until the pipeline closes the channel.
		w.Header().Set("Content-Type", "application/jsonl")
		w.WriteHeader(http.StatusOK)
		enc := json.NewEncoder(w)
		for response := range request.Responses {
			if err := enc.Encode(response); err != nil {
				slog.Error("Failed to encode response", "error", err)
				return
			}

			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	})

	mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
		var b bytes.Buffer
		if _, err := io.Copy(&b, r.Body); err != nil {
			slog.Error("Failed to read request body", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}

		tokens, err := runner.Tokenizer.Encode(b.String(), true)
		if err != nil {
			slog.Error("Failed to tokenize text", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		if err := json.NewEncoder(w).Encode(tokens); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})

	// Redirect legacy runner endpoints to their /v1 equivalents.
	for source, target := range map[string]string{
		"GET /health":      "/v1/status",
		"POST /load":       "/v1/models",
		"POST /completion": "/v1/completions",
	} {
		mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
	}

	return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
		t := time.Now()
		mux.ServeHTTP(recorder, r)

		// Log at a severity matching the status code; redirects are not logged.
		var level slog.Level
		switch {
		case recorder.code >= 500:
			level = slog.LevelError
		case recorder.code >= 400:
			level = slog.LevelWarn
		case recorder.code >= 300:
			return
		}

		slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
	}))
}

type statusRecorder struct {
	http.ResponseWriter
	code int
}

func (w *statusRecorder) WriteHeader(code int) {
	w.code = code
	w.ResponseWriter.WriteHeader(code)
}

func (w *statusRecorder) Status() string {
	return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}

// Flush forwards to the underlying ResponseWriter so streamed responses are
// not buffered by the status recorder.
func (w *statusRecorder) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}
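The /v1/completions handler streams JSON Lines; a hypothetical standard-library client for it (port and prompt are placeholders):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"prompt": "why is the sky blue?", "options": {"temperature": 0.7}}`)
	resp, err := http.Post("http://127.0.0.1:8080/v1/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Each line of the body is one JSON-encoded Response; decoding stops
	// with io.EOF when the server finishes the stream.
	dec := json.NewDecoder(resp.Body)
	for {
		var chunk struct {
			Content string `json:"content"`
			Done    bool   `json:"done"`
		}
		if err := dec.Decode(&chunk); err != nil {
			break
		}
		fmt.Print(chunk.Content)
	}
}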