mirror of
https://github.com/ollama/ollama.git
synced 2026-02-06 21:53:11 -05:00
Compare commits
3 Commits
mxyng/mlx-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
099a0f18ef | ||
|
|
fff696ee31 | ||
|
|
2e3ce6eab3 |
@@ -147,7 +147,7 @@ ARG PARALLEL
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
COPY x/ml/backend/mlx x/ml/backend/mlx
|
||||
COPY x/imagegen/mlx x/imagegen/mlx
|
||||
COPY go.mod go.sum .
|
||||
COPY MLX_VERSION .
|
||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||
|
||||
@@ -897,11 +897,5 @@ func countContentBlock(block any) int {
|
||||
}
|
||||
}
|
||||
|
||||
if source, ok := blockMap["source"].(map[string]any); ok {
|
||||
if data, ok := source["data"].(string); ok {
|
||||
total += len(data)
|
||||
}
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -312,7 +312,7 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
||||
|
||||
5
go.mod
5
go.mod
@@ -13,7 +13,7 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/sys v0.37.0
|
||||
@@ -29,8 +29,6 @@ require (
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
@@ -52,7 +50,6 @@ require (
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-pointer v0.0.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
|
||||
31
go.sum
31
go.sum
@@ -152,8 +152,6 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
|
||||
github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
@@ -208,39 +206,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
|
||||
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
|
||||
github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
|
||||
github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
|
||||
github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
|
||||
github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
|
||||
github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
|
||||
github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
|
||||
github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
|
||||
github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
|
||||
github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
|
||||
github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
|
||||
github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
|
||||
github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
|
||||
github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"github.com/ollama/ollama/runner/llamarunner"
|
||||
"github.com/ollama/ollama/runner/ollamarunner"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
@@ -18,8 +17,6 @@ func Execute(args []string) error {
|
||||
return ollamarunner.Execute(args[1:])
|
||||
case "--imagegen-engine":
|
||||
return imagegen.Execute(args[1:])
|
||||
case "--mlx-engine":
|
||||
return mlxrunner.Execute(args[1:])
|
||||
}
|
||||
}
|
||||
return llamarunner.Execute(args)
|
||||
|
||||
107
server/sched.go
107
server/sched.go
@@ -5,13 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -26,7 +22,6 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/mlxrunner"
|
||||
)
|
||||
|
||||
type LlmRequest struct {
|
||||
@@ -200,14 +195,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if s.loadSafetensors(pending) {
|
||||
// Check for image generation models - all use MLX runner
|
||||
if slices.Contains(pending.model.Config.Capabilities, "image") {
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for experimental safetensors LLM models
|
||||
if pending.model.Config.ModelFormat == "safetensors" {
|
||||
if slices.Contains(pending.model.Config.Capabilities, "completion") {
|
||||
// LLM model with safetensors format - use MLX runner
|
||||
if s.loadMLX(pending) {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Load model for fitting
|
||||
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
|
||||
ggml, err := llm.LoadModel(pending.model.ModelPath, 1024)
|
||||
@@ -557,90 +563,9 @@ iGPUScan:
|
||||
return false
|
||||
}
|
||||
|
||||
func subproc(args, environ []string) (*exec.Cmd, int, error) {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
for range 3 {
|
||||
// get a random port in the ephemeral range
|
||||
port := rand.Intn(65535-49152) + 49152
|
||||
cmd := exec.Command(exe, slices.Concat([]string{"runner"}, args, []string{"--port", strconv.Itoa(port)})...)
|
||||
cmd.Env = slices.Concat(os.Environ(), environ)
|
||||
cmd.Stdout = os.Stderr
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Start(); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return cmd, port, nil
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("unable to start subprocess after multiple attempts")
|
||||
}
|
||||
|
||||
func (s *Scheduler) loadSafetensors(req *LlmRequest) bool {
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
return s.loadImageGen(req)
|
||||
}
|
||||
|
||||
args := []string{"--mlx-engine", "--model", req.model.ShortName}
|
||||
environ := []string{}
|
||||
cmd, port, err := subproc(args, environ)
|
||||
if err != nil {
|
||||
req.errCh <- fmt.Errorf("failed to start mlx subprocess: %w", err)
|
||||
return true
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
Options: &req.opts,
|
||||
loading: false,
|
||||
sessionDuration: sessionDuration,
|
||||
llama: &mlxrunner.Client{
|
||||
Cmd: cmd,
|
||||
Port: port,
|
||||
},
|
||||
}
|
||||
|
||||
s.loadedMu.Lock()
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
runner.refMu.Lock()
|
||||
if sessionDuration > 0 {
|
||||
runner.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
s.expiredCh <- runner
|
||||
})
|
||||
}
|
||||
runner.refMu.Unlock()
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
|
||||
for range time.Tick(20 * time.Millisecond) {
|
||||
if err := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
return runner.llama.Ping(ctx)
|
||||
}(); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// loadImageGen loads an experimental safetensors model using the unified MLX runner.
|
||||
// loadMLX loads an experimental safetensors model using the unified MLX runner.
|
||||
// This supports both LLM (completion) and image generation models.
|
||||
func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
|
||||
func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
||||
// Determine mode based on capabilities
|
||||
var mode imagegen.ModelMode
|
||||
if slices.Contains(req.model.Config.Capabilities, "image") {
|
||||
|
||||
@@ -1,14 +1,5 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
const (
|
||||
TOKEN_TYPE_NORMAL = iota + 1
|
||||
TOKEN_TYPE_UNKNOWN
|
||||
@@ -24,287 +15,3 @@ type Tokenizer interface {
|
||||
Is(int32, Special) bool
|
||||
Vocabulary() *Vocabulary
|
||||
}
|
||||
|
||||
func New(root *model.Root) (Tokenizer, error) {
|
||||
f, err := root.Open("tokenizer.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var tokenizer struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges json.RawMessage `json:"merges"`
|
||||
} `json:"model"`
|
||||
|
||||
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
|
||||
Decoder json.RawMessage `json:"decoder"`
|
||||
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(f).Decode(&tokenizer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
special := make(map[int32]struct{})
|
||||
for _, token := range tokenizer.AddedTokens {
|
||||
tokenizer.Model.Vocab[token.Content] = token.ID
|
||||
special[token.ID] = struct{}{}
|
||||
}
|
||||
|
||||
vocab, err := specialTokens(root, tokenizer.Model.Vocab)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vocab.Values = make([]string, len(tokenizer.Model.Vocab))
|
||||
vocab.Scores = make([]float32, len(tokenizer.Model.Vocab))
|
||||
vocab.Types = make([]int32, len(tokenizer.Model.Vocab))
|
||||
for content, id := range tokenizer.Model.Vocab {
|
||||
vocab.Values[id] = content
|
||||
vocab.Scores[id] = float32(id)
|
||||
vocab.Types[id] = TOKEN_TYPE_NORMAL
|
||||
if _, ok := special[id]; ok {
|
||||
vocab.Types[id] = TOKEN_TYPE_USER_DEFINED
|
||||
}
|
||||
}
|
||||
|
||||
if tokenizer.Model.Merges != nil {
|
||||
var pairs [][]string
|
||||
if err := json.Unmarshal(tokenizer.Model.Merges, &pairs); err == nil {
|
||||
vocab.Merges = make([]string, len(pairs))
|
||||
for i, pair := range pairs {
|
||||
vocab.Merges[i] = pair[0] + " " + pair[1]
|
||||
}
|
||||
} else if err := json.Unmarshal(tokenizer.Model.Merges, &vocab.Merges); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
vocab.valuesOnce.Do(func() {})
|
||||
vocab.values = tokenizer.Model.Vocab
|
||||
|
||||
if tokenizer.Model.Type == "WordPiece" {
|
||||
return NewWordPiece(vocab, true), nil
|
||||
}
|
||||
|
||||
if tokenizer.Decoder != nil {
|
||||
var decoder struct {
|
||||
Type string `json:"type"`
|
||||
Decoders []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
String string `json:"string"`
|
||||
} `json:"pattern"`
|
||||
} `json:"decoders"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(tokenizer.Decoder, &decoder); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decoder.Type == "Sequence" {
|
||||
for _, d := range decoder.Decoders {
|
||||
if d.Type == "Replace" && d.Pattern.String == "▁" {
|
||||
return NewSentencePiece(vocab), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var pretokenizers []string
|
||||
if tokenizer.PreTokenizer != nil {
|
||||
var pretokenizer struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string
|
||||
} `json:"pattern"`
|
||||
IndividualDigits bool `json:"individual_digits"`
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(tokenizer.PreTokenizer, &pretokenizer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pretokenizer.Type == "Sequence" {
|
||||
for _, pretokenizer := range pretokenizer.Pretokenizers {
|
||||
switch pretokenizer.Type {
|
||||
case "Digits":
|
||||
if pretokenizer.IndividualDigits {
|
||||
pretokenizers = append(pretokenizers, `\d`)
|
||||
} else {
|
||||
pretokenizers = append(pretokenizers, `\d+`)
|
||||
}
|
||||
case "Punctuation":
|
||||
pretokenizers = append(pretokenizers, `[^\p{L}\p{N}]+`)
|
||||
case "Split":
|
||||
pretokenizers = append(pretokenizers, pretokenizer.Pattern.Regex)
|
||||
case "WhitespaceSplit":
|
||||
pretokenizers = append(pretokenizers, `\s+(?!\S)|\s+`)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(vocab, pretokenizers...), nil
|
||||
}
|
||||
|
||||
// valueOrValues is a type that can unmarshal from either a single value or an array of values.
|
||||
type valueOrValues[E any] []E
|
||||
|
||||
func (m *valueOrValues[E]) UnmarshalJSON(data []byte) error {
|
||||
var s []E
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
var e E
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
s = []E{e}
|
||||
}
|
||||
*m = valueOrValues[E](s)
|
||||
return nil
|
||||
}
|
||||
|
||||
type specialTokenIDs struct {
|
||||
BOSTokenID valueOrValues[int32] `json:"bos_token_id"`
|
||||
EOSTokenID valueOrValues[int32] `json:"eos_token_id"`
|
||||
}
|
||||
|
||||
// stringOrContent is a type that can unmarshal from either a string or an object with a "content" field.
|
||||
type stringOrContent string
|
||||
|
||||
func (t *stringOrContent) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
if content, ok := m["content"].(string); ok {
|
||||
s = content
|
||||
}
|
||||
}
|
||||
*t = stringOrContent(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
func specialTokens(root *model.Root, values map[string]int32) (*Vocabulary, error) {
|
||||
var vocab Vocabulary
|
||||
for _, c := range []struct {
|
||||
name string
|
||||
fn func(io.Reader) error
|
||||
}{
|
||||
{
|
||||
name: "generation_config.json",
|
||||
fn: func(r io.Reader) error {
|
||||
var c specialTokenIDs
|
||||
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vocab.BOS = c.BOSTokenID
|
||||
vocab.EOS = c.EOSTokenID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "config.json",
|
||||
fn: func(r io.Reader) error {
|
||||
var c specialTokenIDs
|
||||
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(vocab.BOS) == 0 {
|
||||
vocab.BOS = c.BOSTokenID
|
||||
}
|
||||
|
||||
if len(vocab.EOS) == 0 {
|
||||
vocab.EOS = c.EOSTokenID
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tokenizer_config.json",
|
||||
fn: func(r io.Reader) error {
|
||||
var c struct {
|
||||
BOSToken stringOrContent `json:"bos_token"`
|
||||
EOSToken stringOrContent `json:"eos_token"`
|
||||
PADToken stringOrContent `json:"pad_token"`
|
||||
AddBOSToken bool `json:"add_bos_token"`
|
||||
AddEOSToken bool `json:"add_eos_token"`
|
||||
}
|
||||
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(vocab.BOS) == 0 && c.BOSToken != "" {
|
||||
if id, ok := values[string(c.BOSToken)]; ok {
|
||||
vocab.BOS = []int32{id}
|
||||
}
|
||||
}
|
||||
|
||||
if len(vocab.EOS) == 0 && c.EOSToken != "" {
|
||||
if id, ok := values[string(c.EOSToken)]; ok {
|
||||
vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
|
||||
vocab.AddBOS = c.AddBOSToken
|
||||
vocab.AddEOS = c.AddEOSToken
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "special_tokens_map.json",
|
||||
fn: func(r io.Reader) error {
|
||||
var c map[string]stringOrContent
|
||||
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bos, ok := c["bos_token"]; ok && len(vocab.BOS) == 0 {
|
||||
if id, ok := values[string(bos)]; ok {
|
||||
vocab.BOS = []int32{id}
|
||||
}
|
||||
}
|
||||
|
||||
if eos, ok := c["eos_token"]; ok && len(vocab.EOS) == 0 {
|
||||
if id, ok := values[string(eos)]; ok {
|
||||
vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
} {
|
||||
if err := func() error {
|
||||
f, err := root.Open(c.name)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return c.fn(f)
|
||||
}(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &vocab, nil
|
||||
}
|
||||
|
||||
@@ -1,316 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"maps"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func root() (*os.Root, error) {
|
||||
root, err := os.OpenRoot(envconfig.Models())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sub := range []string{"manifests", "blobs"} {
|
||||
if _, err := root.Stat(sub); errors.Is(err, fs.ErrNotExist) {
|
||||
if err := root.MkdirAll(sub, 0o750); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// Open opens an existing file for reading. It will return [fs.ErrNotExist]
|
||||
// if the file does not exist. The returned [*Root] can only be used for reading.
|
||||
// It is the caller's responsibility to close the file when done.
|
||||
func Open(n Name) (*Root, error) {
|
||||
r, err := root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := r.Open(filepath.Join("manifests", n.Filepath()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var m manifest
|
||||
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blobs := make(map[string]*blob, len(m.Layers)+1)
|
||||
blobs[NamePrefix] = m.Config
|
||||
for _, layer := range m.Layers {
|
||||
if layer.Name == "" && layer.MediaType != "" {
|
||||
mediatype, _, err := mime.ParseMediaType(layer.MediaType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if suffix, ok := strings.CutPrefix(mediatype, MediaTypePrefix); ok {
|
||||
layer.Name = NamePrefix + suffix
|
||||
}
|
||||
}
|
||||
|
||||
blobs[layer.Name] = layer
|
||||
}
|
||||
|
||||
return &Root{
|
||||
root: r,
|
||||
name: n,
|
||||
blobs: blobs,
|
||||
flags: os.O_RDONLY,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create creates a new file. The returned [Root] can be used for both reading
|
||||
// and writing. It is the caller's responsibility to close the file when done
|
||||
// in order to finalize any new blobs and write the manifest.
|
||||
func Create(n Name) (*Root, error) {
|
||||
r, err := root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Root{
|
||||
root: r,
|
||||
name: n,
|
||||
blobs: make(map[string]*blob),
|
||||
flags: os.O_RDWR,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type blob struct {
|
||||
Digest string `json:"digest"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Size int64 `json:"size"`
|
||||
|
||||
// tempfile is the temporary file where the blob data is written.
|
||||
tempfile *os.File
|
||||
|
||||
// hash is the hash.Hash used to compute the blob digest.
|
||||
hash hash.Hash
|
||||
}
|
||||
|
||||
func (b *blob) Write(p []byte) (int, error) {
|
||||
return io.MultiWriter(b.tempfile, b.hash).Write(p)
|
||||
}
|
||||
|
||||
func (b *blob) Filepath() string {
|
||||
return strings.ReplaceAll(b.Digest, ":", "-")
|
||||
}
|
||||
|
||||
type manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config *blob `json:"config"`
|
||||
Layers []*blob `json:"layers"`
|
||||
}
|
||||
|
||||
// Root represents a model file. It can be used to read and write blobs
|
||||
// associated with the model.
|
||||
//
|
||||
// Blobs are identified by name. Certain names are special and reserved;
|
||||
// see [NamePrefix] for details.
|
||||
type Root struct {
|
||||
root *os.Root
|
||||
name Name
|
||||
blobs map[string]*blob
|
||||
flags int
|
||||
}
|
||||
|
||||
const MediaTypePrefix = "application/vnd.ollama"
|
||||
|
||||
// NamePrefix is the prefix used for identifying special names. Names
|
||||
// with this prefix are idenfitied by their media types:
|
||||
//
|
||||
// - name: NamePrefix + suffix
|
||||
// - mediaType: [MediaTypePrefix] + suffix
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// - name: "./..image.model"
|
||||
// - mediaType: "application/vnd.ollama.image.model"
|
||||
//
|
||||
// NamePrefix by itself identifies the manifest config.
|
||||
const NamePrefix = "./."
|
||||
|
||||
// Open opens the named blob for reading. It is the caller's responsibility
|
||||
// to close the returned [io.ReadCloser] when done. It will return
|
||||
// [fs.ErrNotExist] if the blob does not exist.
|
||||
func (r Root) Open(name string) (io.ReadCloser, error) {
|
||||
if b, ok := r.blobs[name]; ok {
|
||||
r, err := r.root.Open(filepath.Join("blobs", b.Filepath()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
return nil, fs.ErrNotExist
|
||||
}
|
||||
|
||||
func (r Root) ReadFile(name string) ([]byte, error) {
|
||||
f, err := r.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return io.ReadAll(f)
|
||||
}
|
||||
|
||||
// Create creates or replaces a named blob in the file. If the blob already
|
||||
// exists, it will be overwritten. It will return [fs.ErrInvalid] if the file
|
||||
// was opened in read-only mode. The returned [io.Writer] can be used to write
|
||||
// to the blob and does not need be closed, but the file must be closed to
|
||||
// finalize the blob.
|
||||
func (r *Root) Create(name string) (io.Writer, error) {
|
||||
if r.flags&os.O_RDWR != 0 {
|
||||
w, err := os.CreateTemp(r.root.Name(), "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.blobs[name] = &blob{Name: name, tempfile: w, hash: sha256.New()}
|
||||
return r.blobs[name], nil
|
||||
}
|
||||
|
||||
return nil, fs.ErrInvalid
|
||||
}
|
||||
|
||||
// Close closes the file. If the file was opened in read-write mode, it
|
||||
// will finalize any writeable blobs and write the manifest.
|
||||
func (r *Root) Close() error {
|
||||
if r.flags&os.O_RDWR != 0 {
|
||||
for _, b := range r.blobs {
|
||||
if b.tempfile != nil {
|
||||
fi, err := b.tempfile.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := b.tempfile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Size = fi.Size()
|
||||
b.Digest = fmt.Sprintf("sha256:%x", b.hash.Sum(nil))
|
||||
|
||||
if suffix, ok := strings.CutPrefix(b.Name, NamePrefix); ok {
|
||||
if b.Name == NamePrefix {
|
||||
b.MediaType = "application/vnd.docker.container.image.v1+json"
|
||||
} else {
|
||||
b.MediaType = MediaTypePrefix + suffix
|
||||
}
|
||||
b.Name = ""
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(r.root.Name(), b.tempfile.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.root.Rename(rel, filepath.Join("blobs", b.Filepath())); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p := filepath.Join("manifests", r.name.Filepath())
|
||||
if _, err := r.root.Stat(filepath.Dir(p)); errors.Is(err, os.ErrNotExist) {
|
||||
if err := r.root.MkdirAll(filepath.Dir(p), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := r.root.OpenFile(p, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o640)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := json.NewEncoder(f).Encode(manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: r.blobs[NamePrefix],
|
||||
Layers: func() []*blob {
|
||||
blobs := make([]*blob, 0, len(r.blobs))
|
||||
for name, b := range r.blobs {
|
||||
if name != NamePrefix {
|
||||
blobs = append(blobs, b)
|
||||
}
|
||||
}
|
||||
return blobs
|
||||
}(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.root.Close()
|
||||
}
|
||||
|
||||
// Name returns the name of the file.
|
||||
func (r Root) Name() Name {
|
||||
return r.name
|
||||
}
|
||||
|
||||
// Names returns an iterator over the names in the file.
|
||||
func (r Root) Names() iter.Seq[string] {
|
||||
return maps.Keys(r.blobs)
|
||||
}
|
||||
|
||||
// Glob returns an iterator over the names in the file that match the given
|
||||
// pattern.
|
||||
//
|
||||
// The pattern syntax is the same as [filepath.Match]. As with filepath.Match,
|
||||
// the only possible returned error is ErrBadPattern, when pattern is malformed.
|
||||
func (r Root) Glob(pattern string) (iter.Seq[string], error) {
|
||||
if _, err := filepath.Match(pattern, ""); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(yield func(string) bool) {
|
||||
for name := range r.blobs {
|
||||
if matched, _ := filepath.Match(pattern, name); matched {
|
||||
if !yield(name) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r Root) JoinPath(parts ...string) string {
|
||||
return filepath.Join(append([]string{r.root.Name()}, parts...)...)
|
||||
}
|
||||
|
||||
func (r Root) Real(name string) string {
|
||||
if b, ok := r.blobs[name]; ok {
|
||||
return b.Filepath()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// setup is a helper function to set up the test environment.
|
||||
func setup(t *testing.T, models map[Name]map[string]io.Reader) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
for m, s := range models {
|
||||
f, err := Create(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for n, r := range s {
|
||||
w, err := f.Create(n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpen verifies reading blobs from a stored model, the error returned
// for a missing model, and that handles returned by Open are read-only.
func TestOpen(t *testing.T) {
	// Seed three models; only "namespace/model" is opened below, the others
	// exist to make sure lookup does not match the wrong entry.
	setup(t, map[Name]map[string]io.Reader{
		ParseName("namespace/model"): {
			"./.": strings.NewReader(`{"key":"value"}`),
		},
		ParseName("namespace/model:8b"): {
			"./.": strings.NewReader(`{"foo":"bar"}`),
		},
		ParseName("another/model"): {
			"./.": strings.NewReader(`{"another":"config"}`),
		},
	})

	f, err := Open(ParseName("namespace/model"))
	if err != nil {
		t.Fatal(err)
	}

	// Read each seeded blob end-to-end and close it.
	for _, name := range []string{"./."} {
		r, err := f.Open(name)
		if err != nil {
			t.Fatal(err)
		}

		if _, err := io.ReadAll(r); err != nil {
			t.Fatal(err)
		}

		if err := r.Close(); err != nil {
			t.Fatal(err)
		}
	}

	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	t.Run("does not exist", func(t *testing.T) {
		if _, err := Open(ParseName("namespace/unknown")); err == nil {
			t.Error("expected error for unknown model")
		}
	})

	// Handles from Open must reject writes.
	t.Run("write", func(t *testing.T) {
		f, err := Open(ParseName("namespace/model"))
		if err != nil {
			t.Fatal(err)
		}
		defer f.Close()

		if _, err := f.Create("new-blob"); err == nil {
			t.Error("expected error creating blob in read-only mode")
		}
	})
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"iter"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func All() (iter.Seq[Name], error) {
|
||||
r, err := root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manifests, err := r.OpenRoot("manifests")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
matches, err := fs.Glob(manifests.FS(), "*/*/*/*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(yield func(Name) bool) {
|
||||
for _, match := range matches {
|
||||
name := ParseNameFromFilepath(filepath.ToSlash(match))
|
||||
if !yield(name) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
@@ -227,17 +227,6 @@ func (n Name) String() string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Set implements [flag.Value]. It parses the provided input as a name string
|
||||
// and sets the receiver to the parsed value. If the parsed name is not valid,
|
||||
// ErrUnqualifiedName is returned.
|
||||
func (n *Name) Set(s string) error {
|
||||
*n = ParseName(s)
|
||||
if !n.IsValid() {
|
||||
return ErrUnqualifiedName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisplayShortest returns a short string version of the name.
|
||||
func (n Name) DisplayShortest() string {
|
||||
var sb strings.Builder
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
// CacheEntry is a node in a trie keyed by token ids. A node with a
// non-empty Caches slice marks a prompt prefix whose per-layer KV caches
// were saved and can be resumed.
type CacheEntry struct {
	Caches []cache.Cache // saved per-layer caches for the prefix ending here, if any
	Count int // times this exact prefix was re-inserted (see InsertCache)
	Entries map[int32]*CacheEntry // children keyed by the next token id
}
|
||||
|
||||
// FindNearestCache walks the token trie for the longest previously saved
// prefix of tokens. It returns the caches to resume from (possibly cloned
// and trimmed) and the suffix of tokens that still needs to be processed;
// on a miss it returns (nil, tokens).
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
	// Walk as far as tokens match; cacheIndex remembers the deepest index
	// at which a visited node actually held saved caches.
	current := &CacheEntry{Entries: s.CacheEntries}
	index, cacheIndex := 0, -1
	for _, token := range tokens {
		if _, ok := current.Entries[token]; !ok {
			break
		}

		current = current.Entries[token]
		if len(current.Caches) > 0 {
			cacheIndex = index
		}

		index += 1
	}

	if cacheIndex == len(tokens)-1 {
		// Entire prompt was cached: nothing left to prefill.
		slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
		return current.Caches, []int32{}
	} else if cacheIndex > 1 {
		// NOTE(review): matched prefixes of length 1 or 2 (cacheIndex 0 or 1)
		// deliberately fall through to the branches below — confirm intended.
		slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
		return current.Caches, tokens[cacheIndex+1:]
	} else if index > 0 && cacheIndex < 0 {
		// Some tokens matched the trie but no visited node held caches.
		// Depth-first search below `current` for the nearest descendant that
		// does, then clone and trim its caches back to the matched depth.
		type stackItem struct {
			entry *CacheEntry
			tokens []int32
		}

		var best, item stackItem
		stack := []stackItem{{entry: current, tokens: []int32{}}}
		for len(stack) > 0 {
			item, stack = stack[len(stack)-1], stack[:len(stack)-1]
			if len(item.entry.Caches) > 0 {
				// Prefer the shallowest descendant (fewest extra tokens to trim).
				if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
					best = item
				}
			} else {
				for token, entry := range item.entry.Entries {
					stack = append(stack, stackItem{
						entry: entry,
						tokens: append(item.tokens, token),
					})
				}
			}
		}

		// NOTE(review): if no descendant holds caches, best.entry is nil and
		// the dereference below panics — confirm that case is unreachable.
		prefix := min(len(tokens)-1, index)
		caches := make([]cache.Cache, len(best.entry.Caches))
		trim := len(best.tokens)+1
		for i := range caches {
			caches[i] = best.entry.Caches[i].Clone()
			caches[i].Trim(trim)
		}

		slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
		return caches, tokens[prefix:]
	}

	slog.Info("Cache miss", "left", len(tokens))
	return nil, tokens
}
|
||||
|
||||
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
current := &CacheEntry{Entries: s.CacheEntries}
|
||||
for _, token := range tokens {
|
||||
if _, ok := current.Entries[token]; !ok {
|
||||
current.Entries[token] = &CacheEntry{
|
||||
Entries: make(map[int32]*CacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
current = current.Entries[token]
|
||||
}
|
||||
|
||||
if len(current.Caches) > 0 {
|
||||
current.Count += 1
|
||||
} else {
|
||||
current.Caches = caches
|
||||
}
|
||||
}
|
||||
196
x/mlxrunner/cache/cache.go
vendored
196
x/mlxrunner/cache/cache.go
vendored
@@ -1,196 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// Cache is the per-layer KV cache contract used by the MLX runner.
type Cache interface {
	// Update appends keys/values and returns views over the cached state.
	Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
	// State returns the currently valid cached keys and values.
	State() (keys, values *mlx.Array)
	// Trim removes up to n recent positions, returning how many were removed.
	Trim(int) int
	// Clone returns an independent copy of the cache.
	Clone() Cache
	// Offset is the total number of positions ever written.
	Offset() int
	// Len is the number of currently valid cached positions.
	Len() int
}
|
||||
|
||||
// KVCache is an unbounded KV cache whose buffers grow along the sequence
// axis (dimension 2) in step-sized chunks.
type KVCache struct {
	keys, values *mlx.Array // backing buffers; only positions [0, offset) are valid
	offset int // number of positions written so far
	step int // growth granularity, in positions
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
// Update appends keys/values (shaped [B, H, L, D]) at the current offset,
// growing the backing buffers in step-sized chunks as needed, and returns
// views over every position written so far.
func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if needed
	if c.keys == nil || (prev+L) > c.keys.Dim(2) {
		// (step + L - 1) / step == ceil(L / step): allocate whole steps to
		// cover the incoming L positions.
		steps := (c.step + L - 1) / c.step
		newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)

		if c.keys != nil {
			// If the last chunk is partially filled, truncate to the written
			// prefix before concatenating the fresh zero chunks.
			if prev%c.step != 0 {
				c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
				c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
			}
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
	}

	// Write the new positions into [prev, offset) and return views limited
	// to the valid region.
	c.offset += L
	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))

	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
|
||||
|
||||
// State returns views over the written portion of the cache, or the raw
// buffers when they are exactly full.
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
	if c.offset == c.keys.Dim(2) {
		return c.keys, c.values
	}
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *KVCache) Clone() Cache {
|
||||
return &KVCache{
|
||||
keys: c.keys.Clone(),
|
||||
values: c.values.Clone(),
|
||||
offset: c.offset,
|
||||
step: c.step,
|
||||
}
|
||||
}
|
||||
|
||||
// Offset returns the total number of positions written to the cache.
func (c *KVCache) Offset() int { return c.offset }

// Len returns the number of valid cached positions (equal to Offset for a
// non-rotating cache).
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
	maxSize int // maximum number of positions retained in the window
	idx int // next write position within the ring buffer

	*KVCache
}
|
||||
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
if keys.Dim(2) > 1 {
|
||||
return c.concat(keys, values)
|
||||
}
|
||||
return c.update(keys, values)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = keys, values
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
}
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
if trim := c.idx - c.maxSize + 1; trim > 0 {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
|
||||
}
|
||||
|
||||
c.keys.Set(c.keys.Concatenate(2, keys))
|
||||
c.values.Set(c.values.Concatenate(2, values))
|
||||
c.idx = c.keys.Dim(2)
|
||||
}
|
||||
|
||||
c.offset += keys.Dim(2)
|
||||
c.idx = c.keys.Dim(2)
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
// update writes a single decode step into the ring buffer, growing it up to
// maxSize and wrapping the write index once the window is full. It returns
// views over the currently valid positions.
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
	B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)

	prev := c.offset

	// Grow buffer if not yet at max
	if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
		// NOTE(review): assumes prev < maxSize here, otherwise newSize would
		// be non-positive — confirm callers never exceed the window on this path.
		newSize := min(c.step, c.maxSize-prev)
		newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
		newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
		if c.keys != nil {
			c.keys.Set(c.keys.Concatenate(2, newKeys))
			c.values.Set(c.values.Concatenate(2, newValues))
		} else {
			c.keys, c.values = newKeys, newValues
		}
		c.idx = prev
	}

	// Trim to max_size to maintain sliding window
	if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
		c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
		c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
		c.idx = c.maxSize
	}

	// Rotate when hitting max
	if c.idx >= c.maxSize {
		c.idx = 0
	}

	c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
	c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))

	c.offset += L
	c.idx += L

	// Only the first min(offset, maxSize) positions hold real data.
	validLen := min(c.offset, c.maxSize)
	return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
		c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
|
||||
|
||||
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
if c.offset < c.keys.Dim(2) {
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
// Trim rewinds up to n of the most recent positions and reports how many
// were removed.
// NOTE(review): idx can go negative when n exceeds idx after the ring has
// wrapped — confirm callers only trim positions written since the last wrap.
func (c *RotatingKVCache) Trim(n int) int {
	n = min(c.offset, n)
	c.offset -= n
	c.idx -= n
	return n
}
|
||||
|
||||
func (c *RotatingKVCache) Clone() Cache {
|
||||
return &RotatingKVCache{
|
||||
maxSize: c.maxSize,
|
||||
idx: c.idx,
|
||||
KVCache: c.KVCache.Clone().(*KVCache),
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of valid positions, capped by the window size.
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
@@ -1,174 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// Client talks to an out-of-process MLX runner over its local HTTP API and
// holds the runner's process handle via the embedded exec.Cmd.
type Client struct {
	Port int // localhost TCP port the runner's HTTP server listens on

	*exec.Cmd
}
|
||||
|
||||
func (c *Client) JoinPath(path string) string {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)),
|
||||
}).JoinPath(path).String()
|
||||
}
|
||||
|
||||
func (c *Client) CheckError(w *http.Response) error {
|
||||
if w.StatusCode >= 400 {
|
||||
return errors.New(w.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements llm.LlamaServer.
// It force-kills the runner process; any in-flight requests are dropped.
func (c *Client) Close() error {
	return c.Cmd.Process.Kill()
}
|
||||
|
||||
// Completion implements llm.LlamaServer.
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
if err := c.CheckError(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(w.Body)
|
||||
for scanner.Scan() {
|
||||
bts := scanner.Bytes()
|
||||
|
||||
var resp llm.CompletionResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextLength implements llm.LlamaServer. The client does not enforce a
// context limit, so it reports math.MaxInt.
func (c *Client) ContextLength() int {
	return math.MaxInt
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
	panic("unimplemented")
}
|
||||
|
||||
// Embedding implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
	panic("unimplemented")
}
|
||||
|
||||
// GetDeviceInfos implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	panic("unimplemented")
}
|
||||
|
||||
// GetPort implements llm.LlamaServer, returning the runner's HTTP port.
func (c *Client) GetPort() int {
	return c.Port
}
|
||||
|
||||
// HasExited implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) HasExited() bool {
	panic("unimplemented")
}
|
||||
|
||||
// Load implements llm.LlamaServer.
|
||||
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
return []ml.DeviceID{}, nil
|
||||
}
|
||||
|
||||
// ModelPath implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) ModelPath() string {
	panic("unimplemented")
}
|
||||
|
||||
// Pid implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) Pid() int {
	panic("unimplemented")
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
w, err := http.Get(c.JoinPath("/v1/status"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tokenize implements llm.LlamaServer.
|
||||
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer w.Body.Close()
|
||||
|
||||
var tokens []int
|
||||
if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// TotalSize implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) TotalSize() uint64 {
	panic("unimplemented")
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
	panic("unimplemented")
}
|
||||
|
||||
// VRAMSize implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) VRAMSize() uint64 {
	panic("unimplemented")
}
|
||||
|
||||
// WaitUntilRunning implements llm.LlamaServer.
// Not supported by this client; calling it panics.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
	panic("unimplemented")
}
|
||||
|
||||
// Compile-time check that *Client satisfies the llm.LlamaServer interface.
var _ llm.LlamaServer = (*Client)(nil)
|
||||
3
x/mlxrunner/mlx/.gitignore
vendored
3
x/mlxrunner/mlx/.gitignore
vendored
@@ -1,3 +0,0 @@
|
||||
_deps
|
||||
build
|
||||
dist
|
||||
@@ -1,26 +0,0 @@
|
||||
# Wrapper build that fetches mlx-c (the C API for Apple's MLX) and builds it
# as shared libraries, installed under ./dist by default.
cmake_minimum_required(VERSION 3.5)

project(mlx)

# Default the install prefix to <src>/dist unless the caller overrides it.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()

# Safetensors on, GGUF and examples off; shared libs so they can be dlopen'd.
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)

# Resolve sibling dylibs relative to the loading binary (macOS @loader_path).
set(CMAKE_INSTALL_RPATH "@loader_path")

include(FetchContent)

set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG}
)

FetchContent_MakeAvailable(mlx-c)
|
||||
@@ -1,45 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// geluApprox lazily compiles the tanh approximation of GELU:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// casting the result back to the input's dtype. Compiled shapeless (second
// argument true) so one closure serves all input shapes.
var geluApprox = sync.OnceValue(func() *Closure {
	return Compile(func(inputs []*Array) []*Array {
		input := inputs[0]
		return []*Array{
			input.Multiply(
				FromValue[float32](0.5),
			).Multiply(
				input.Add(
					input.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
				).Multiply(
					FromValue(float32(math.Sqrt(2 / math.Pi))),
				).Tanh().Add(FromValue[float32](1.0)),
			).AsType(input.DType()),
		}
	}, true)
})
|
||||
|
||||
// silu lazily compiles SiLU (x * sigmoid(x)), casting the result back to
// the input's dtype. Compiled shapeless so one closure serves all shapes.
var silu = sync.OnceValue(func() *Closure {
	return Compile(func(inputs []*Array) []*Array {
		input := inputs[0]
		return []*Array{
			input.Multiply(
				input.Sigmoid(),
			).AsType(input.DType()),
		}
	}, true)
})
|
||||
|
||||
// GELUApprox applies the compiled tanh-approximate GELU activation to t.
func GELUApprox(t *Array) *Array {
	return geluApprox().Call([]*Array{t})[0]
}
|
||||
|
||||
// SILU applies the compiled SiLU activation to t.
func SILU(t *Array) *Array {
	return silu().Call([]*Array{t})[0]
}
|
||||
@@ -1,271 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// tensorDesc records an Array's debug name, the input arrays it was derived
// from, and how many downstream consumers still reference it (used by Free).
type tensorDesc struct {
	name string // debug label, may be empty
	inputs []*Array // arrays this one was computed from
	numRefs int // outstanding consumers; freed when it reaches zero
}
|
||||
|
||||
func (d tensorDesc) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", d.name),
|
||||
slog.Int("inputs", len(d.inputs)),
|
||||
slog.Int("num_refs", d.numRefs),
|
||||
)
|
||||
}
|
||||
|
||||
// Array wraps an mlx_array handle together with provenance metadata used
// for logging and reference-counted freeing.
type Array struct {
	ctx C.mlx_array // underlying mlx-c handle; zero until assigned
	desc tensorDesc // debug name, inputs, and refcount
}
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{
|
||||
desc: tensorDesc{
|
||||
name: name,
|
||||
inputs: inputs,
|
||||
},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
input.desc.numRefs++
|
||||
}
|
||||
logutil.Trace("New", "t", t)
|
||||
return t
|
||||
}
|
||||
|
||||
// scalarTypes enumerates the Go scalar types FromValue can convert into a
// 0-dimensional mlx array.
type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}
|
||||
|
||||
// FromValue creates a scalar (0-d) array from a Go value, choosing the mlx
// constructor from the value's dynamic type.
// NOTE(review): the type switch matches exact types, so a named type that
// satisfies a ~-constraint (e.g. `type MyBool bool`) hits the panic default —
// confirm callers only pass the built-in types.
func FromValue[T scalarTypes](t T) *Array {
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		panic("unsupported type")
	}
	return tt
}
|
||||
|
||||
// arrayTypes enumerates the Go element types FromValues can serialize into
// an mlx array.
type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}
|
||||
|
||||
// FromValues creates an array with the given shape from a flat slice of Go
// values. The mlx dtype follows the slice's element kind; data is encoded
// little-endian and copied into mlx. Panics when shape is empty or the
// element kind is unsupported.
// NOTE(review): len(s) is not checked against the product of shape here —
// confirm mlx_array_new_data validates it.
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	if len(shape) == 0 {
		panic("shape must be provided for non-scalar tensors")
	}

	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}

	// Map the element kind (works for named types too, unlike a type switch)
	// to the corresponding mlx dtype.
	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("unsupported type")
	}

	// Serialize the slice little-endian into a contiguous byte buffer that
	// mlx copies from.
	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}

	tt := New("")
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}
|
||||
|
||||
// Set rebinds t's handle to other's array in place, recording other as t's
// sole input and bumping its refcount for Free's bookkeeping.
func (t *Array) Set(other *Array) {
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}
|
||||
|
||||
// Clone returns a new Array bound to the same underlying mlx data as t,
// carrying over t's name and inputs (whose refcounts New increments).
func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}
|
||||
|
||||
// misc. utilities
|
||||
|
||||
// Valid reports whether t is still bound to a live mlx_array (it becomes
// false after Free releases the handle).
func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}
|
||||
|
||||
// String renders the array via mlx's own formatter, copying the result into
// a Go string and freeing the temporary mlx string.
func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
|
||||
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{slog.Any("", t.desc)}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
slog.Any("dtype", t.DType()),
|
||||
slog.Any("shape", t.Dims()),
|
||||
slog.Int("num_bytes", t.NumBytes()),
|
||||
)
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// shape utilities
|
||||
|
||||
// Size returns the total number of elements in the array.
func (t Array) Size() int {
	return int(C.mlx_array_size(t.ctx))
}
|
||||
|
||||
// NumBytes returns the array's storage size in bytes.
func (t Array) NumBytes() int {
	return int(C.mlx_array_nbytes(t.ctx))
}
|
||||
|
||||
// NumDims returns the array's rank (number of dimensions).
func (t Array) NumDims() int {
	return int(C.mlx_array_ndim(t.ctx))
}
|
||||
|
||||
func (t Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
}
|
||||
|
||||
return dims
|
||||
}
|
||||
|
||||
// Dim returns the size of the given dimension.
func (t Array) Dim(dim int) int {
	return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
|
||||
|
||||
// DType returns the array's element type.
func (t Array) DType() DType {
	return DType(C.mlx_array_dtype(t.ctx))
}
|
||||
|
||||
// data utilities
|
||||
|
||||
// Int extracts a scalar array's value as an int (via mlx's int64 item).
func (t Array) Int() int {
	var item C.int64_t
	C.mlx_array_item_int64(&item, t.ctx)
	return int(item)
}
|
||||
|
||||
// Float extracts a scalar array's value as a float64.
func (t Array) Float() float64 {
	var item C.double
	C.mlx_array_item_float64(&item, t.ctx)
	return float64(item)
}
|
||||
|
||||
// Ints copies the array's contents out as []int.
// NOTE(review): reads via mlx_array_data_int32, so this assumes the array's
// dtype is int32 — confirm callers guarantee that.
func (t Array) Ints() []int {
	ints := make([]int, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
		ints[i] = int(f)
	}
	return ints
}
|
||||
|
||||
// Floats copies the array's contents out as []float32.
// NOTE(review): reads via mlx_array_data_float32, so this assumes the
// array's dtype is float32 — confirm callers guarantee that.
func (t Array) Floats() []float32 {
	floats := make([]float32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
		floats[i] = float32(f)
	}
	return floats
}
|
||||
|
||||
// Save writes the array to the file name via mlx_save.
// NOTE(review): mlx_save's failure is not surfaced here, so the returned
// error is always nil — confirm whether that is acceptable to callers.
func (t Array) Save(name string) error {
	cName := C.CString(name)
	defer C.free(unsafe.Pointer(cName))
	C.mlx_save(cName, t.ctx)
	return nil
}
|
||||
|
||||
// Free decrements the refcount of each array in s and releases any array
// whose count drops to (or below) zero, then transitively processes that
// array's inputs. Traversal uses an explicit stack to avoid deep recursion
// on long compute graphs. It returns the number of bytes released.
func Free(s ...*Array) (n int) {
	now := time.Now()
	defer func() {
		if n > 0 {
			logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
		}
	}()

	free := make([]*Array, 0, 8192)
	// fn drops one reference from t; when none remain, the mlx handle is
	// freed and t's inputs are queued for the same treatment.
	fn := func(t *Array) {
		if t.Valid() {
			free = append(free, t.desc.inputs...)
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				logutil.Trace("Free", "t", t)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				// Mark the handle dead so a second Free is a no-op.
				t.ctx.ctx = nil
			}
		}
	}

	for _, t := range s {
		fn(t)
	}

	// Drain the work stack (LIFO) until no queued inputs remain.
	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}

	return n
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package mlx
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFromValue(t *testing.T) {
|
||||
for got, want := range map[*Tensor]DType{
|
||||
FromValue(true): DTypeBool,
|
||||
FromValue(false): DTypeBool,
|
||||
FromValue(int(7)): DTypeInt32,
|
||||
FromValue(float32(3.14)): DTypeFloat32,
|
||||
FromValue(float64(2.71)): DTypeFloat64,
|
||||
FromValue(complex64(1 + 2i)): DTypeComplex64,
|
||||
} {
|
||||
t.Run(want.String(), func(t *testing.T) {
|
||||
if got.DType() != want {
|
||||
t.Errorf("want %v, got %v", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromValues(t *testing.T) {
|
||||
for got, want := range map[*Tensor]DType{
|
||||
FromValues([]bool{true, false, true}, 3): DTypeBool,
|
||||
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
|
||||
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
|
||||
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
|
||||
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
|
||||
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
|
||||
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
|
||||
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
|
||||
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
|
||||
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
|
||||
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
|
||||
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
|
||||
} {
|
||||
t.Run(want.String(), func(t *testing.T) {
|
||||
if got.DType() != want {
|
||||
t.Errorf("want %v, got %v", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
// int goClosureFunc(mlx_vector_array*, mlx_vector_array, void*);
|
||||
// void goClosureDestructor(void*);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"runtime/cgo"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Closure wraps an mlx_closure handle (a compiled array-to-array function).
type Closure struct {
	ctx C.mlx_closure
}
|
||||
|
||||
// Call applies the compiled closure to inputs and returns its outputs, each
// wrapped in a new Array that records the call's inputs for reference
// counting.
func (c Closure) Call(inputs []*Array) []*Array {
	// Marshal the Go inputs into a temporary mlx vector.
	inputsVector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(inputsVector)

	for _, input := range inputs {
		C.mlx_vector_array_append_value(inputsVector, input.ctx)
	}

	outputsVector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(outputsVector)

	C.mlx_closure_apply(&outputsVector, c.ctx, inputsVector)

	// Unpack the output vector; each output lists the call's inputs as its
	// provenance so Free can walk the graph.
	outputs := make([]*Array, int(C.mlx_vector_array_size(outputsVector)))
	for i := range outputs {
		t := New("", inputs...)
		C.mlx_vector_array_get(&t.ctx, outputsVector, C.size_t(i))
		outputs[i] = t
	}

	return outputs
}
|
||||
|
||||
// Compile wraps fn as an mlx closure and returns its compiled form;
// shapeless controls whether the compiled graph is reused across differing
// input shapes. fn is kept alive via a cgo.Handle that goClosureDestructor
// releases when mlx drops the closure.
// NOTE(review): the intermediate `closure` handle is not freed here —
// confirm mlx_compile takes ownership of it.
func Compile(fn func([]*Array) []*Array, shapeless bool) *Closure {
	closure := C.mlx_closure_new_func_payload(
		(*[0]byte)(C.goClosureFunc),
		unsafe.Pointer(cgo.NewHandle(fn)),
		(*[0]byte)(C.goClosureDestructor),
	)

	compiled := C.mlx_closure_new()
	C.mlx_compile(&compiled, closure, C.bool(shapeless))
	return &Closure{ctx: compiled}
}
|
||||
|
||||
// goClosureFunc is the cgo trampoline invoked by mlx when a compiled
// closure runs: it unpacks the input vector, calls the Go function stored
// in the cgo.Handle payload, and writes the outputs into outputsVector.
// It always reports success (0).
//
//export goClosureFunc
func goClosureFunc(outputsVector *C.mlx_vector_array, inputsVector C.mlx_vector_array, payload unsafe.Pointer) C.int {
	handle := cgo.Handle(payload)
	fn := handle.Value().(func([]*Array) []*Array)

	inputs := make([]*Array, int(C.mlx_vector_array_size(inputsVector)))
	for i := range inputs {
		t := New("")
		C.mlx_vector_array_get(&t.ctx, inputsVector, C.size_t(i))
		inputs[i] = t
	}

	var outputs []C.mlx_array
	for _, output := range fn(inputs) {
		outputs = append(outputs, output.ctx)
	}

	C.mlx_vector_array_set_data(outputsVector, unsafe.SliceData(outputs), C.size_t(len(outputs)))
	return 0
}
|
||||
|
||||
// goClosureDestructor releases the cgo.Handle that pins the Go function for
// a closure, called by mlx when the closure is destroyed.
//
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
	cgo.Handle(payload).Delete()
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package mlx

// #include "generated.h"
import "C"

import "fmt"
||||
// DType identifies an mlx element type; values alias the mlx-c enum.
type DType int
|
||||
func (t DType) String() string {
|
||||
switch t {
|
||||
case DTypeBool:
|
||||
return "BOOL"
|
||||
case DTypeUint8:
|
||||
return "U8"
|
||||
case DTypeUint16:
|
||||
return "U16"
|
||||
case DTypeUint32:
|
||||
return "U32"
|
||||
case DTypeUint64:
|
||||
return "U64"
|
||||
case DTypeInt8:
|
||||
return "I8"
|
||||
case DTypeInt16:
|
||||
return "I16"
|
||||
case DTypeInt32:
|
||||
return "I32"
|
||||
case DTypeInt64:
|
||||
return "I64"
|
||||
case DTypeFloat16:
|
||||
return "F16"
|
||||
case DTypeFloat32:
|
||||
return "F32"
|
||||
case DTypeFloat64:
|
||||
return "F64"
|
||||
case DTypeBFloat16:
|
||||
return "BF16"
|
||||
case DTypeComplex64:
|
||||
return "C64"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DType) UnmarshalJSON(b []byte) error {
|
||||
switch string(b) {
|
||||
case `"BOOL"`:
|
||||
*t = DTypeBool
|
||||
case `"U8"`:
|
||||
*t = DTypeUint8
|
||||
case `"U16"`:
|
||||
*t = DTypeUint16
|
||||
case `"U32"`:
|
||||
*t = DTypeUint32
|
||||
case `"U64"`:
|
||||
*t = DTypeUint64
|
||||
case `"I8"`:
|
||||
*t = DTypeInt8
|
||||
case `"I16"`:
|
||||
*t = DTypeInt16
|
||||
case `"I32"`:
|
||||
*t = DTypeInt32
|
||||
case `"I64"`:
|
||||
*t = DTypeInt64
|
||||
case `"F16"`:
|
||||
*t = DTypeFloat16
|
||||
case `"F64"`:
|
||||
*t = DTypeFloat64
|
||||
case `"F32"`:
|
||||
*t = DTypeFloat32
|
||||
case `"BF16"`:
|
||||
*t = DTypeBFloat16
|
||||
case `"C64"`:
|
||||
*t = DTypeComplex64
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
DTypeBool DType = C.MLX_BOOL
|
||||
DTypeUint8 DType = C.MLX_UINT8
|
||||
DTypeUint16 DType = C.MLX_UINT16
|
||||
DTypeUint32 DType = C.MLX_UINT32
|
||||
DTypeUint64 DType = C.MLX_UINT64
|
||||
DTypeInt8 DType = C.MLX_INT8
|
||||
DTypeInt16 DType = C.MLX_INT16
|
||||
DTypeInt32 DType = C.MLX_INT32
|
||||
DTypeInt64 DType = C.MLX_INT64
|
||||
DTypeFloat16 DType = C.MLX_FLOAT16
|
||||
DTypeFloat32 DType = C.MLX_FLOAT32
|
||||
DTypeFloat64 DType = C.MLX_FLOAT64
|
||||
DTypeBFloat16 DType = C.MLX_BFLOAT16
|
||||
DTypeComplex64 DType = C.MLX_COMPLEX64
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
#include "dynamic.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLOPEN(path) LoadLibraryA(path)
|
||||
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
|
||||
#else
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#include <dlfcn.h>
|
||||
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define DLCLOSE(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
|
||||
handle->ctx = (void*) DLOPEN(path);
|
||||
CHECK(handle->ctx != NULL);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
|
||||
return mlx_dynamic_open(handle, path);
|
||||
}
|
||||
|
||||
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
|
||||
if (handle->ctx) {
|
||||
DLCLOSE(handle->ctx);
|
||||
handle->ctx = NULL;
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "dynamic.h"
|
||||
// #include "generated.h"
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
|
||||
case "windows":
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
|
||||
if !ok {
|
||||
slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
|
||||
return
|
||||
}
|
||||
|
||||
for _, path := range filepath.SplitList(paths) {
|
||||
matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(paths, match)
|
||||
slog.Info("Loading MLX dynamic library", "path", path)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library", "path", path)
|
||||
continue
|
||||
}
|
||||
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
slog.Error("Failed to load MLX dynamic library symbols", "path", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("Loaded MLX dynamic library", "path", path)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
panic("Failed to load any MLX dynamic library")
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
|
||||
#endif
|
||||
|
||||
#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
|
||||
#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
|
||||
#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
|
||||
|
||||
typedef struct {
|
||||
void* ctx;
|
||||
} mlx_dynamic_handle;
|
||||
|
||||
int mlx_dynamic_load(
|
||||
mlx_dynamic_handle* handle,
|
||||
const char *path);
|
||||
|
||||
void mlx_dynamic_unload(
|
||||
mlx_dynamic_handle* handle);
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
@@ -1,72 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
|
||||
if mask == nil {
|
||||
mask = New("")
|
||||
}
|
||||
|
||||
sinks := New("")
|
||||
|
||||
mode := "causal"
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Bias Array `weight:"bias"`
|
||||
}
|
||||
|
||||
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM", x)
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM", x)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type RoPE struct {
|
||||
Dims int
|
||||
Traditional bool
|
||||
Base float32 `json:"rope_theta"`
|
||||
Scale float32
|
||||
}
|
||||
|
||||
func (r RoPE) Forward(t *Array, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE", t, freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
C.int(r.Dims),
|
||||
C._Bool(r.Traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(r.Base),
|
||||
has_value: C._Bool(func() bool { return r.Base != 0 }()),
|
||||
},
|
||||
C.float(r.Scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,24 +0,0 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#include "generated.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
{{ range .Functions }}
|
||||
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
{{- range .Functions }}
|
||||
CHECK_LOAD(handle, {{ .Name }});
|
||||
{{- end }}
|
||||
return 0;
|
||||
}
|
||||
|
||||
{{- range .Functions }}
|
||||
|
||||
{{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
|
||||
return {{ .Name }}_({{ .Args }});
|
||||
{{ "}" }}
|
||||
{{- end }}
|
||||
@@ -1,20 +0,0 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#ifndef MLX_GENERATED_H
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
{{- end }}
|
||||
{{ range .Functions }}
|
||||
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
|
||||
{{ range .Functions }}
|
||||
{{ .Type }} {{ .Name }}{{ .Parameters }};
|
||||
{{- end }}
|
||||
|
||||
#endif // MLX_GENERATED_H
|
||||
@@ -1,135 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
tree_sitter "github.com/tree-sitter/go-tree-sitter"
|
||||
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
|
||||
)
|
||||
|
||||
//go:embed *.gotmpl
|
||||
var fsys embed.FS
|
||||
|
||||
type Function struct {
|
||||
Type,
|
||||
Name,
|
||||
Parameters,
|
||||
Args string
|
||||
}
|
||||
|
||||
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
||||
var fn Function
|
||||
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
|
||||
if params := node.ChildByFieldName("parameters"); params != nil {
|
||||
fn.Parameters = params.Utf8Text(source)
|
||||
fn.Args = ParseParameters(params, tc, source)
|
||||
}
|
||||
|
||||
var types []string
|
||||
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
|
||||
if node.Parent().Kind() == "pointer_declarator" {
|
||||
types = append(types, "*")
|
||||
}
|
||||
node = node.Parent()
|
||||
}
|
||||
|
||||
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
|
||||
types = append(types, sibling.Utf8Text(source))
|
||||
}
|
||||
|
||||
slices.Reverse(types)
|
||||
fn.Type = strings.Join(types, " ")
|
||||
return fn
|
||||
}
|
||||
|
||||
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
|
||||
var s []string
|
||||
for _, child := range node.Children(tc) {
|
||||
if child.IsNamed() {
|
||||
child := child.ChildByFieldName("declarator")
|
||||
for child != nil && child.Kind() != "identifier" {
|
||||
if child.Kind() == "parenthesized_declarator" {
|
||||
child = child.Child(1)
|
||||
} else {
|
||||
child = child.ChildByFieldName("declarator")
|
||||
}
|
||||
}
|
||||
|
||||
if child != nil {
|
||||
s = append(s, child.Utf8Text(source))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(s, ", ")
|
||||
}
|
||||
|
||||
func main() {
|
||||
var output string
|
||||
flag.StringVar(&output, "output", ".", "Output directory for generated files")
|
||||
flag.Parse()
|
||||
|
||||
parser := tree_sitter.NewParser()
|
||||
defer parser.Close()
|
||||
|
||||
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
|
||||
parser.SetLanguage(language)
|
||||
|
||||
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
|
||||
defer query.Close()
|
||||
|
||||
qc := tree_sitter.NewQueryCursor()
|
||||
defer qc.Close()
|
||||
|
||||
var funs []Function
|
||||
for _, arg := range flag.Args() {
|
||||
bts, err := os.ReadFile(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
|
||||
tree := parser.Parse(bts, nil)
|
||||
defer tree.Close()
|
||||
|
||||
tc := tree.Walk()
|
||||
defer tc.Close()
|
||||
|
||||
matches := qc.Matches(query, tree.RootNode(), bts)
|
||||
for match := matches.Next(); match != nil; match = matches.Next() {
|
||||
for _, capture := range match.Captures {
|
||||
funs = append(funs, ParseFunction(&capture.Node, tc, bts))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tmpl := range tmpl.Templates() {
|
||||
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
|
||||
|
||||
fmt.Println("Generating", name)
|
||||
f, err := os.Create(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := tmpl.Execute(f, map[string]any{
|
||||
"Functions": funs,
|
||||
}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"slices"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func Load(path string) iter.Seq2[string, *Array] {
|
||||
return func(yield func(string, *Array) bool) {
|
||||
string2array := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(string2array)
|
||||
|
||||
string2string := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(string2string)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cpu := C.mlx_default_cpu_stream_new()
|
||||
defer C.mlx_stream_free(cpu)
|
||||
|
||||
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(string2array)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
var key *C.char
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Parse(root *model.Root, path string) (map[string]Quantization, error) {
|
||||
f, err := root.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var n uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bts := make([]byte, n)
|
||||
if _, err := io.ReadFull(f, bts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var m struct {
|
||||
Metadata struct {
|
||||
Quantization map[string]Quantization `json:"quantization"`
|
||||
} `json:"__metadata__"`
|
||||
}
|
||||
if err := json.Unmarshal(bts, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Metadata.Quantization, nil
|
||||
}
|
||||
|
||||
func LoadWeights(root *model.Root, match string, states map[string]*Array) error {
|
||||
slog.Debug("Loading weights from", "file", match)
|
||||
for name, weight := range Load(root.JoinPath("blobs", root.Real(match))) {
|
||||
if state, ok := states[name]; ok {
|
||||
*state = *weight
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadQuantizations(root *model.Root, match string, quantizations map[string]*Quantization) error {
|
||||
slog.Debug("Loading quantizations from", "file", match)
|
||||
metadata, err := Parse(root, match)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name := range metadata {
|
||||
if q, ok := quantizations[name+".weight"]; ok {
|
||||
q.GroupSize = metadata[name].GroupSize
|
||||
q.Bits = metadata[name].Bits
|
||||
q.Mode = metadata[name].Mode
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AfterLoadFunc func(*model.Root) ([]*Array, error)
|
||||
|
||||
func LoadAll(root *model.Root, states map[string]*Array, quantizations map[string]*Quantization, afterLoadFuncs []AfterLoadFunc) error {
|
||||
matches, err := root.Glob("model*.safetensors")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for match := range matches {
|
||||
if err := errors.Join(
|
||||
LoadWeights(root, match, states),
|
||||
LoadQuantizations(root, match, quantizations),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, afterLoadFunc := range afterLoadFuncs {
|
||||
weights, err := afterLoadFunc(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, weight := range weights {
|
||||
weight.desc.numRefs = 1000
|
||||
Eval(weight)
|
||||
|
||||
var freeAll func(...*Array)
|
||||
freeAll = func(inputs ...*Array) {
|
||||
for _, input := range inputs {
|
||||
input.desc.numRefs = 0
|
||||
freeAll(input.desc.inputs...)
|
||||
}
|
||||
Free(inputs...)
|
||||
}
|
||||
|
||||
freeAll(weight.desc.inputs...)
|
||||
}
|
||||
}
|
||||
|
||||
Eval(slices.Collect(maps.Values(states))...)
|
||||
ClearCache()
|
||||
slog.Info("Loaded weights", "count", len(states), "memory", Memory{})
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnloadAll(states map[string]*Array) {
|
||||
weights := slices.Collect(maps.Values(states))
|
||||
for _, weight := range weights {
|
||||
weight.desc.numRefs = 0
|
||||
}
|
||||
|
||||
numBytes := Free(weights...)
|
||||
slog.Info("Unloaded weights", "count", len(states), "num_bytes", PrettyBytes(numBytes), "memory", Memory{})
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (b Byte) String() string {
|
||||
return strconv.FormatInt(int64(b), 10) + " B"
|
||||
}
|
||||
|
||||
func (b KibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
|
||||
}
|
||||
|
||||
func (b MebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
|
||||
}
|
||||
|
||||
func (b GibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
|
||||
}
|
||||
|
||||
func (b TebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
|
||||
}
|
||||
|
||||
func PrettyBytes(n int) fmt.Stringer {
|
||||
switch {
|
||||
case n < 1<<10:
|
||||
return Byte(n)
|
||||
case n < 1<<(2*10):
|
||||
return KibiByte(n)
|
||||
case n < 1<<(3*10):
|
||||
return MebiByte(n)
|
||||
case n < 1<<(4*10):
|
||||
return GibiByte(n)
|
||||
default:
|
||||
return TebiByte(n)
|
||||
}
|
||||
}
|
||||
|
||||
func ActiveMemory() int {
|
||||
var active C.size_t
|
||||
C.mlx_get_active_memory(&active)
|
||||
return int(active)
|
||||
}
|
||||
|
||||
func CacheMemory() int {
|
||||
var cache C.size_t
|
||||
C.mlx_get_cache_memory(&cache)
|
||||
return int(cache)
|
||||
}
|
||||
|
||||
func PeakMemory() int {
|
||||
var peak C.size_t
|
||||
C.mlx_get_peak_memory(&peak)
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Any("active", PrettyBytes(ActiveMemory())),
|
||||
slog.Any("cache", PrettyBytes(CacheMemory())),
|
||||
slog.Any("peak", PrettyBytes(PeakMemory())),
|
||||
)
|
||||
}
|
||||
|
||||
type (
|
||||
Byte int
|
||||
KibiByte int
|
||||
MebiByte int
|
||||
GibiByte int
|
||||
TebiByte int
|
||||
)
|
||||
|
||||
func ClearCache() {
|
||||
C.mlx_clear_cache()
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package mlx
|
||||
|
||||
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
||||
//go:generate cmake --build build --parallel
|
||||
//go:generate cmake --install build
|
||||
//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
|
||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
if async {
|
||||
C.mlx_async_eval(vector)
|
||||
} else {
|
||||
C.mlx_eval(vector)
|
||||
}
|
||||
}
|
||||
|
||||
func AsyncEval(outputs ...*Array) {
|
||||
doEval(outputs, true)
|
||||
}
|
||||
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package mlx
|
||||
|
||||
import "cmp"
|
||||
|
||||
type Quantization struct {
|
||||
Scales Array `weight:"scales"`
|
||||
Biases Array `weight:"biases"`
|
||||
GroupSize int `json:"group_size"`
|
||||
Bits int `json:"bits"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
type Linear struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Bias Array `weight:"bias"`
|
||||
|
||||
Quantization
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation: x @ Weight.T + Bias
|
||||
func (m Linear) Forward(x *Array) *Array {
|
||||
if m.Scales.Valid() {
|
||||
x = x.QuantizedMatmul(
|
||||
&m.Weight,
|
||||
&m.Scales,
|
||||
&m.Biases,
|
||||
true,
|
||||
m.GroupSize,
|
||||
m.Bits,
|
||||
cmp.Or(m.Mode, "affine"),
|
||||
)
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
w := m.Weight.Transpose(1, 0)
|
||||
if m.Bias.Valid() {
|
||||
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
}
|
||||
|
||||
return x.Matmul(w)
|
||||
}
|
||||
|
||||
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
if m.Scales.Valid() {
|
||||
x = x.GatherQMM(
|
||||
&m.Weight,
|
||||
&m.Scales,
|
||||
&m.Biases,
|
||||
lhs,
|
||||
rhs,
|
||||
sorted,
|
||||
m.GroupSize,
|
||||
m.Bits,
|
||||
cmp.Or(m.Mode, "affine"),
|
||||
sorted,
|
||||
)
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
return x
|
||||
} else {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
x = x.GatherMM(w, lhs, rhs, sorted)
|
||||
}
|
||||
|
||||
if m.Bias.Valid() {
|
||||
x = m.Bias.Add(x)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Weight Array `weight:"weight"`
|
||||
|
||||
Quantization
|
||||
}
|
||||
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
if e.Scales.Valid() {
|
||||
w := e.Weight.TakeAxis(indices, 0)
|
||||
return w.Dequantize(
|
||||
e.Scales.TakeAxis(indices, 0),
|
||||
e.Biases.TakeAxis(indices, 0),
|
||||
e.GroupSize,
|
||||
e.Bits,
|
||||
cmp.Or(e.Mode, "affine"),
|
||||
)
|
||||
}
|
||||
|
||||
return e.Weight.TakeAxis(indices, 0)
|
||||
}
|
||||
|
||||
func (e *Embedding) AsLinear() Linear {
|
||||
return Linear{
|
||||
Weight: e.Weight,
|
||||
Quantization: e.Quantization,
|
||||
}
|
||||
}
|
||||
@@ -1,341 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (t *Array) Abs() *Array {
|
||||
out := New("ABS", t)
|
||||
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Add(other *Array) *Array {
|
||||
out := New("ADD", t, other)
|
||||
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
|
||||
out := New("ADDMM", t, a, b)
|
||||
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Argmax(axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX", t)
|
||||
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
|
||||
out := New("ARGPARTITION", t)
|
||||
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgsortAxis(axis int) *Array {
|
||||
out := New("ARGSORT_AXIS", t)
|
||||
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsType(dtype DType) *Array {
|
||||
out := New("AS_TYPE", t)
|
||||
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
|
||||
cStrides := make([]C.int64_t, len(strides))
|
||||
for i, s := range strides {
|
||||
cStrides[i] = C.int64_t(s)
|
||||
}
|
||||
|
||||
out := New("AS_STRIDED", t)
|
||||
C.mlx_as_strided(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(cShape), C.size_t(len(shape)),
|
||||
unsafe.SliceData(cStrides), C.size_t(len(strides)),
|
||||
C.size_t(offset),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
s := append([]*Array{t}, others...)
|
||||
for _, other := range s {
|
||||
C.mlx_vector_array_append_value(vector, other.ctx)
|
||||
}
|
||||
|
||||
out := New("CONCATENATE", s...)
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Divide(other *Array) *Array {
|
||||
out := New("DIVIDE", t, other)
|
||||
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Dequantize(scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||
out := New("DEQUANTIZE", t, scales, biases)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
C.mlx_dequantize(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
C.mlx_optional_dtype{
|
||||
has_value: false,
|
||||
},
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS", t)
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Flatten(startAxis, endAxis int) *Array {
|
||||
out := New("FLATTEN", t)
|
||||
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) FloorDivide(other *Array) *Array {
|
||||
out := New("FLOOR_DIVIDE", t, other)
|
||||
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
if lhs == nil {
|
||||
lhs = New("")
|
||||
}
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_MM", t, other, lhs, rhs)
|
||||
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GatherQMM(weight, scales, biases, lhs, rhs *Array, transpose bool, groupSize, bits int, mode string, sorted bool) *Array {
|
||||
if lhs == nil {
|
||||
lhs = New("")
|
||||
}
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_QMM", t, weight, scales, biases, lhs, rhs)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
C.mlx_gather_qmm(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
weight.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
lhs.ctx,
|
||||
rhs.ctx,
|
||||
C.bool(transpose),
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
C.bool(sorted),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP", t)
|
||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL", t, other)
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Multiply(other *Array) *Array {
|
||||
out := New("MULTIPLY", t, other)
|
||||
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Negative() *Array {
|
||||
out := New("NEGATIVE", t)
|
||||
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Power(exponent *Array) *Array {
|
||||
out := New("POWER", t, exponent)
|
||||
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("PUT_ALONG_AXIS", t, indices, values)
|
||||
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) QuantizedMatmul(weight, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
||||
out := New("QUANTIZED_MATMUL", t, weight, scales, biases)
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
C.mlx_quantized_matmul(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
weight.ctx,
|
||||
scales.ctx,
|
||||
biases.ctx,
|
||||
C.bool(transpose),
|
||||
C.mlx_optional_int{
|
||||
value: C.int(groupSize),
|
||||
has_value: C.bool(groupSize > 0),
|
||||
},
|
||||
C.mlx_optional_int{
|
||||
value: C.int(bits),
|
||||
has_value: C.bool(bits > 0),
|
||||
},
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
}
|
||||
|
||||
out := New("RESHAPE", t)
|
||||
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sigmoid() *Array {
|
||||
out := New("SIGMOID", t)
|
||||
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sqrt() *Array {
|
||||
out := New("SQRT", t)
|
||||
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Squeeze(axis int) *Array {
|
||||
out := New("SQUEEZE", t)
|
||||
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
|
||||
vectorData := make([]C.mlx_array, len(others)+1)
|
||||
vectorData[0] = t.ctx
|
||||
for i := range others {
|
||||
vectorData[i+1] = others[i].ctx
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK_AXIS", append(others, t)...)
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Subtract(other *Array) *Array {
|
||||
out := New("SUBTRACT", t, other)
|
||||
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
|
||||
out := New("SUM_AXIS", t)
|
||||
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// TakeAxis gathers elements from t at the given indices along axis.
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
	out := New("TAKE_AXIS", t, indices)
	C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// TakeAlongAxis gathers from t using an indices array that matches t's
// shape except along axis (numpy take_along_axis semantics).
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
	out := New("TAKE_ALONG_AXIS", t, indices)
	C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// Tanh returns the element-wise hyperbolic tangent of t.
func (t *Array) Tanh() *Array {
	out := New("TANH", t)
	C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
	return out
}
|
||||
|
||||
// Transpose returns t with its dimensions permuted according to axes.
func (t *Array) Transpose(axes ...int) *Array {
	cAxes := make([]C.int, len(axes))
	for i, axis := range axes {
		cAxes[i] = C.int(axis)
	}

	out := New("TRANSPOSE", t)
	C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// Zeros returns a new array of the given dtype and shape filled with zeros.
func Zeros(dtype DType, shape ...int) *Array {
	cAxes := make([]C.int, len(shape))
	for i := range shape {
		cAxes[i] = C.int(shape[i])
	}

	t := New("ZEROS")
	C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
	return t
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package mlx

// #include "generated.h"
import "C"

// Categorical samples indices along axis from the scores in t via
// mlx_random_categorical.
//
// NOTE(review): a brand-new array is created per call and passed as the
// RNG key, and both ops are created with empty names unlike every other
// op in this package ("SIGMOID", "SLICE", ...) — confirm both are
// intentional.
func (t *Array) Categorical(axis int) *Array {
	key := New("")
	out := New("", t, key)
	C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
	return out
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// slice holds the raw arguments of a single-dimension slice specifier,
// interpreted by makeSlices ([], [i], [i:j], or [i:j:k]).
type slice struct {
	args []int
}

// Slice builds a per-dimension specifier from 0-3 ints, mirroring
// Go/Python slicing: Slice() is [:], Slice(i) is [i], Slice(i, j) is
// [i:j], and Slice(i, j, k) is [i:j:k].
func Slice(args ...int) slice {
	return slice{args: args}
}
|
||||
|
||||
// makeSlices expands per-dimension slice specifiers into parallel start,
// stop, and stride vectors for mlx_slice/mlx_slice_update. It panics when
// len(slices) != len(dims) or a specifier carries more than three args.
//
// NOTE: a stop value of 0 is rewritten to the end of the dimension via
// cmp.Or, so an explicitly empty range [i:0] cannot be expressed.
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
	if len(slices) != len(dims) {
		panic("number of slice arguments must match number of tensor dimensions")
	}

	args := [3][]C.int{
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
		make([]C.int, len(slices)),
	}

	for i, s := range slices {
		switch len(s.args) {
		case 0:
			// slice[:] — the entire dimension.
			args[0][i] = C.int(0)
			args[1][i] = C.int(dims[i])
			args[2][i] = C.int(1)
		case 1:
			// slice[i] — a single index, kept as a size-1 range.
			args[0][i] = C.int(s.args[0])
			args[1][i] = C.int(s.args[0] + 1)
			args[2][i] = C.int(1)
		case 2:
			// slice[i:j]; j == 0 means "to the end".
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(1)
		case 3:
			// slice[i:j:k]; j == 0 means "to the end".
			args[0][i] = C.int(s.args[0])
			args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
			args[2][i] = C.int(s.args[2])
		default:
			panic("invalid slice arguments")
		}
	}

	return args[0], args[1], args[2]
}
|
||||
|
||||
// Slice returns the sub-array of t selected by one specifier per
// dimension (see Slice and makeSlices for the specifier forms).
func (t *Array) Slice(slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE", t)
	C.mlx_slice(
		&out.ctx, t.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}
|
||||
|
||||
// SliceUpdate returns a copy of t in which the region selected by the
// specifiers (one per dimension of t) is replaced with other.
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)
	out := New("SLICE_UPDATE", t, other)
	C.mlx_slice_update(
		&out.ctx, t.ctx, other.ctx,
		unsafe.SliceData(starts), C.size_t(len(starts)),
		unsafe.SliceData(stops), C.size_t(len(stops)),
		unsafe.SliceData(strides), C.size_t(len(strides)),
		DefaultStream().ctx,
	)
	return out
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Device wraps an mlx_device handle.
type Device struct {
	ctx C.mlx_device
}

// LogValue renders the device's MLX string representation for slog.
func (d Device) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_device_tostring(&str, d.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

// DefaultDevice lazily fetches MLX's default device exactly once; the
// handle is never freed for the life of the process.
var DefaultDevice = sync.OnceValue(func() Device {
	d := C.mlx_device_new()
	C.mlx_get_default_device(&d)
	return Device{d}
})
|
||||
|
||||
// Stream wraps an mlx_stream handle on which operations are scheduled.
type Stream struct {
	ctx C.mlx_stream
}

// LogValue renders the stream's MLX string representation for slog.
func (s Stream) LogValue() slog.Value {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_stream_tostring(&str, s.ctx)
	return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}

// DefaultStream lazily fetches the default stream of DefaultDevice
// exactly once; the handle is never freed for the life of the process.
var DefaultStream = sync.OnceValue(func() Stream {
	s := C.mlx_stream_new()
	C.mlx_get_default_stream(&s, DefaultDevice().ctx)
	return Stream{s}
})
|
||||
@@ -1,8 +0,0 @@
|
||||
package base

import "github.com/ollama/ollama/x/mlxrunner/cache"

// Cacher is implemented by models that support custom caching mechanisms.
type Cacher interface {
	// Cache returns one cache per model layer.
	Cache() []cache.Cache
}
|
||||
@@ -1,116 +0,0 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// Model is the minimal contract a runnable MLX model must satisfy.
type Model interface {
	// Forward performs a forward pass through the model.
	Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array

	// NumLayers returns the number of layers in the model.
	// This is used to initialize caches.
	// TODO: consider moving cache initialization into the model itself.
	NumLayers() int
}

// TextGeneration is a Model that can additionally project hidden states
// back to vocabulary logits.
type TextGeneration interface {
	Model
	Unembed(*mlx.Array) *mlx.Array
}
|
||||
|
||||
// Walk reflects over the model struct and collects, keyed by dotted
// "weight" tag path: every mlx.Array field, every mlx.Quantization field,
// and every AfterLoad method encountered (in traversal order). It
// recurses through exported struct, pointer, interface, slice, and array
// fields; slice/array indices become path components.
func Walk(m Model) (map[string]*mlx.Array, map[string]*mlx.Quantization, []mlx.AfterLoadFunc) {
	weights := make(map[string]*mlx.Array)
	quantizations := make(map[string]*mlx.Quantization)
	var afterLoadFuncs []mlx.AfterLoadFunc
	var fn func(v reflect.Value, tags []string)
	fn = func(v reflect.Value, tags []string) {
		t := v.Type()

		// Collect an AfterLoad hook if this addressable value defines one.
		if method := v.Addr().MethodByName("AfterLoad"); method.IsValid() {
			var afterLoadFunc mlx.AfterLoadFunc
			reflect.ValueOf(&afterLoadFunc).Elem().Set(method)
			afterLoadFuncs = append(afterLoadFuncs, afterLoadFunc)
		}

		if t == reflect.TypeOf((*mlx.Array)(nil)).Elem() {
			name := strings.Join(tags, ".")
			weights[name] = v.Addr().Interface().(*mlx.Array)
			return
		} else if t == reflect.TypeOf((*mlx.Quantization)(nil)).Elem() {
			quantizations[strings.Join(tags, ".")] = v.Addr().Interface().(*mlx.Quantization)
		}

		for _, field := range reflect.VisibleFields(t) {
			if field.IsExported() {
				tt, vv := field.Type, v.FieldByIndex(field.Index)

				// create local copy so tags are not modified between fields
				// NOTE(review): this copies only the slice header; append
				// below can still share the backing array between sibling
				// fields. Safe today because each path is joined into a
				// string before the next sibling appends, but fragile.
				tags := tags
				if tag := field.Tag.Get("weight"); tag != "" {
					// TODO: use model.Tag
					tags = append(tags, tag)
				}

				switch tt.Kind() {
				case reflect.Interface:
					vv = vv.Elem()
					fallthrough
				case reflect.Pointer:
					vv = vv.Elem()
					fallthrough
				case reflect.Struct:
					fn(vv, tags)
				case reflect.Slice, reflect.Array:
					for i := range vv.Len() {
						fn(vv.Index(i), append(tags, strconv.Itoa(i)))
					}
				}
			}
		}
	}
	fn(reflect.ValueOf(m).Elem(), []string{})
	return weights, quantizations, afterLoadFuncs
}
|
||||
|
||||
// m maps architecture names (as found in config.json "architectures") to
// model constructors.
var m = make(map[string]func(*model.Root) (Model, error))

// Register installs a constructor for the named architecture. It panics
// if the name is already registered; intended to be called from package
// init functions.
func Register(name string, f func(*model.Root) (Model, error)) {
	if _, exists := m[name]; exists {
		panic("model already registered: " + name)
	}

	m[name] = f
}
|
||||
|
||||
func New(root *model.Root) (Model, error) {
|
||||
c, err := root.Open("config.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
var config struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(c).Decode(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Model architecture", "arch", config.Architectures[0])
|
||||
if f, exists := m[config.Architectures[0]]; exists {
|
||||
return f(root)
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown architecture")
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package gemma
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// Model wraps the Gemma 3 text model for conditional generation; only the
// language-model weights are represented here.
type Model struct {
	Text TextModel `weight:"language_model"`
}

// NumLayers returns the number of decoder layers in the text model.
func (m *Model) NumLayers() int {
	return len(m.Text.Layers)
}

// Cache builds one cache per layer: every SlidingWindowPattern-th layer
// gets a full KV cache, all others get a rotating cache bounded by the
// sliding window size.
func (m Model) Cache() []cache.Cache {
	caches := make([]cache.Cache, m.NumLayers())
	for i := range caches {
		if (i+1)%m.Text.Options.SlidingWindowPattern == 0 {
			caches[i] = cache.NewKVCache()
		} else {
			caches[i] = cache.NewRotatingKVCache(m.Text.Options.SlidingWindow)
		}
	}
	return caches
}

// Forward delegates the forward pass to the text model.
func (m *Model) Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array {
	return m.Text.Forward(inputs, cache)
}

// Unembed projects hidden states to vocabulary logits by reusing the
// token-embedding matrix as a tied linear output layer.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.Text.EmbedTokens.AsLinear().Forward(x)
}
|
||||
|
||||
// init registers the Gemma 3 constructor. The constructor decodes the
// "text_config" section of config.json, fills in Gemma defaults for
// fields the checkpoint may omit, and prepares the two RoPE variants
// (keyed by whether a layer uses global attention).
func init() {
	base.Register("Gemma3ForConditionalGeneration", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}

		var opts struct {
			Text TextOptions `json:"text_config"`
		}

		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}

		// Defaults for fields a checkpoint config may omit.
		opts.Text.NumAttentionHeads = cmp.Or(opts.Text.NumAttentionHeads, 8)
		opts.Text.NumKeyValueHeads = cmp.Or(opts.Text.NumKeyValueHeads, 4)
		opts.Text.HeadDim = cmp.Or(opts.Text.HeadDim, 256)
		opts.Text.RMSNormEps = cmp.Or(opts.Text.RMSNormEps, 1e-6)
		opts.Text.SlidingWindowPattern = cmp.Or(opts.Text.SlidingWindowPattern, 6)

		// TODO: implement json.Unmarshaler
		// RoPE[true] is used by global-attention layers, RoPE[false] by
		// sliding-window (local) layers.
		opts.Text.RoPE = map[bool]mlx.RoPE{
			true:  {Dims: opts.Text.HeadDim, Traditional: false, Base: 1_000_000, Scale: 1. / 8.},
			false: {Dims: opts.Text.HeadDim, Traditional: false, Base: 10_000, Scale: 1},
		}

		return &Model{
			Text: TextModel{
				Layers:  make([]TextDecoderLayer, opts.Text.NumHiddenLayers),
				Options: opts.Text,
			},
		}, nil
	})
}
|
||||
|
||||
// RMSNorm specializes mlx.RMSNorm for Gemma checkpoints, which
// presumably store the norm scale zero-centered (weight - 1).
type RMSNorm struct {
	mlx.RMSNorm
}

// AfterLoad rebases the loaded weight by adding 1, converting it to the
// multiplicative scale used at inference. No extra arrays are retained.
func (m *RMSNorm) AfterLoad(*model.Root) ([]*mlx.Array, error) {
	m.Weight.Set(m.Weight.Add(mlx.FromValue(1)))
	return []*mlx.Array{}, nil
}
|
||||
@@ -1,118 +0,0 @@
|
||||
package gemma
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// TextOptions holds the Gemma 3 text-model hyperparameters decoded from
// the "text_config" section of config.json.
type TextOptions struct {
	HiddenSize           int     `json:"hidden_size"`
	NumHiddenLayers      int     `json:"num_hidden_layers"`
	IntermediateSize     int     `json:"intermediate_size"`
	NumAttentionHeads    int     `json:"num_attention_heads"`
	NumKeyValueHeads     int     `json:"num_key_value_heads"`
	HeadDim              int     `json:"head_dim"`
	RMSNormEps           float32 `json:"rms_norm_eps"`
	SlidingWindow        int     `json:"sliding_window"`
	SlidingWindowPattern int     `json:"sliding_window_pattern"` // every Nth layer uses global attention

	// RoPE variants keyed by "is global-attention layer"; populated by
	// the package init, not by JSON.
	RoPE map[bool]mlx.RoPE
}
|
||||
|
||||
// TextModel is the Gemma 3 decoder stack.
type TextModel struct {
	EmbedTokens mlx.Embedding      `weight:"model.embed_tokens"`
	Layers      []TextDecoderLayer `weight:"model.layers"`
	Norm        RMSNorm            `weight:"model.norm"`

	Options TextOptions
}

// Forward embeds the (B, L) token ids, applies the Gemma embedding scale,
// runs every decoder layer with its cache and the RoPE variant matching
// the layer's global/local attention role, and applies the final norm.
// It returns hidden states, not logits (see Model.Unembed).
func (m TextModel) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	hiddenStates := m.EmbedTokens.Forward(inputs)

	// Gemma scales embeddings by sqrt(hidden_size).
	hiddenSize := mlx.FromValue(m.Options.HiddenSize).AsType(hiddenStates.DType())
	hiddenStates = hiddenStates.Multiply(hiddenSize.Sqrt())

	for i, layer := range m.Layers {
		// RoPE[true] for global layers (every SlidingWindowPattern-th),
		// RoPE[false] for sliding-window layers.
		hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options.RoPE[(i+1)%m.Options.SlidingWindowPattern == 0], m.Options)
	}

	hiddenStates = m.Norm.Forward(hiddenStates, m.Options.RMSNormEps)
	return hiddenStates
}
|
||||
|
||||
// TextDecoderLayer is one Gemma 3 transformer block. Gemma normalizes
// both before and after attention and the MLP (four norms per layer).
type TextDecoderLayer struct {
	InputNorm    RMSNorm       `weight:"input_layernorm"`
	Attention    TextAttention `weight:"self_attn"`
	PostAttnNorm RMSNorm       `weight:"post_attention_layernorm"`
	PreFFNorm    RMSNorm       `weight:"pre_feedforward_layernorm"`
	MLP          TextMLP       `weight:"mlp"`
	PostFFNorm   RMSNorm       `weight:"post_feedforward_layernorm"`
}

// Forward runs norm→attention→norm with a residual add, then
// norm→MLP→norm with a second residual add.
func (m TextDecoderLayer) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	residual := hiddenStates
	hiddenStates = m.InputNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.Attention.Forward(hiddenStates, cache, B, L, rope, opts)
	hiddenStates = m.PostAttnNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)

	residual = hiddenStates
	hiddenStates = m.PreFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.MLP.Forward(hiddenStates, opts)
	hiddenStates = m.PostFFNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = hiddenStates.Add(residual)
	return hiddenStates
}
|
||||
|
||||
// TextAttention is Gemma 3 grouped-query attention with QK-norm.
type TextAttention struct {
	QProj mlx.Linear `weight:"q_proj"`
	QNorm RMSNorm    `weight:"q_norm"`
	KProj mlx.Linear `weight:"k_proj"`
	KNorm RMSNorm    `weight:"k_norm"`
	VProj mlx.Linear `weight:"v_proj"`
	OProj mlx.Linear `weight:"o_proj"`
}

// Forward projects hidden states to Q/K/V, views them as
// (B, heads, L, headDim), applies QK-norm and RoPE, updates the cache,
// and runs scaled dot-product attention before the output projection.
func (m TextAttention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, rope mlx.RoPE, opts TextOptions) *mlx.Array {
	query := m.QProj.Forward(hiddenStates)
	key := m.KProj.Forward(hiddenStates)
	value := m.VProj.Forward(hiddenStates)

	// AsStrided views the (B, L, heads*headDim) projections directly as
	// (B, heads, L, headDim) — the strides express the (B, L, H, D) →
	// (B, H, L, D) transpose without moving data.
	query = query.AsStrided(
		[]int{B, opts.NumAttentionHeads, L, opts.HeadDim},
		[]int{L * opts.NumAttentionHeads * opts.HeadDim, opts.HeadDim, opts.NumAttentionHeads * opts.HeadDim, 1},
		0)
	key = key.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)
	value = value.AsStrided(
		[]int{B, opts.NumKeyValueHeads, L, opts.HeadDim},
		[]int{L * opts.NumKeyValueHeads * opts.HeadDim, opts.HeadDim, opts.NumKeyValueHeads * opts.HeadDim, 1},
		0)

	query = m.QNorm.Forward(query, opts.RMSNormEps)
	key = m.KNorm.Forward(key, opts.RMSNormEps)

	// Position embeddings are offset by the cache length for decoding.
	query = rope.Forward(query, cache.Offset())
	key = rope.Forward(key, cache.Offset())
	key, value = cache.Update(key, value)

	attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(opts.HeadDim))))
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}
|
||||
|
||||
// TextMLP is Gemma's gated feed-forward block (GELU-approx gating).
type TextMLP struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

// Forward computes down(gelu(gate(h)) * up(h)).
func (m TextMLP) Forward(h *mlx.Array, opts TextOptions) *mlx.Array {
	return m.DownProj.Forward(mlx.GELUApprox(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h)))
}
|
||||
@@ -1,334 +0,0 @@
|
||||
package glm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// Options holds GLM-4 MoE hyperparameters decoded from config.json.
type Options struct {
	HiddenSize        int     `json:"hidden_size"`
	NumHiddenLayers   int     `json:"num_hidden_layers"`
	IntermediateSize  int     `json:"intermediate_size"`
	NumAttentionHeads int     `json:"num_attention_heads"`
	NumKeyValueHeads  int     `json:"num_key_value_heads"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`

	// Multi-head latent attention (MLA) dimensions: low-rank query/KV
	// projections plus the split of head dims into RoPE/no-RoPE parts.
	QLoraRank     int `json:"q_lora_rank"`
	KVLoraRank    int `json:"kv_lora_rank"`
	QKRopeHeadDim int `json:"qk_rope_head_dim"`
	QKNopeHeadDim int `json:"qk_nope_head_dim"`

	// Mixture-of-experts routing configuration.
	NumRoutedExperts    int     `json:"n_routed_experts"`
	NumSharedExperts    int     `json:"n_shared_experts"`
	NumExpertsPerTok    int     `json:"num_experts_per_tok"`
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"`
	FirstKDenseReplace  int     `json:"first_k_dense_replace"` // leading layers that stay dense

	mlx.RoPE
}
|
||||
|
||||
// Model is the GLM-4 MoE Lite causal language model.
type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	LMHead      mlx.Linear    `weight:"lm_head"`

	Options
}

// NumLayers returns the number of decoder layers.
func (m Model) NumLayers() int {
	return len(m.Layers)
}

// Forward embeds the (B, L) token ids, runs each layer with its cache,
// and applies the final norm. Logits are produced separately by Unembed.
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	B, L := inputs.Dim(0), inputs.Dim(1)
	h := m.EmbedTokens.Forward(inputs)
	for i, layer := range m.Layers {
		h = layer.Forward(h, caches[i], B, L, m.Options)
	}

	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}

// Unembed projects hidden states to vocabulary logits via the LM head.
func (m Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}
|
||||
|
||||
// Layer is one GLM-4 transformer block (pre-norm attention + MLP, each
// with a residual connection). MLP is an interface so the first
// FirstKDenseReplace layers can stay dense while the rest use MoE.
type Layer struct {
	InputLayernorm         mlx.RMSNorm `weight:"input_layernorm"`
	Attention              Attention   `weight:"self_attn"`
	PostAttentionLayernorm mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP                    MLP         `weight:"mlp"`
}

// Forward runs norm→attention with a residual add, then norm→MLP with a
// second residual add.
func (m Layer) Forward(h *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	r := h
	h = m.InputLayernorm.Forward(h, opts.RMSNormEps)
	h = m.Attention.Forward(h, cache, B, L, opts)
	h = h.Add(r)

	r = h
	h = m.PostAttentionLayernorm.Forward(h, opts.RMSNormEps)
	h = m.MLP.Forward(h, B, L, opts)
	h = h.Add(r)
	return h
}
|
||||
|
||||
// MultiLinear applies a per-head linear projection via a batched matmul.
// The weight is 3-D (presumably (heads, out, in) — see Attention.AfterLoad
// where it is derived from the KV up-projection).
type MultiLinear struct {
	Weight mlx.Array `weight:"weight"`
}

// Forward computes x @ Weight^T over the leading (head) dimension.
func (m MultiLinear) Forward(x *mlx.Array) *mlx.Array {
	return x.Matmul(m.Weight.Transpose(0, 2, 1))
}
|
||||
|
||||
// Attention implements GLM-4's multi-head latent attention (MLA): queries
// and keys/values are factored through low-rank ("lora") projections, and
// only the compressed KV latent plus the RoPE part of the key is cached.
type Attention struct {
	QAProj      mlx.Linear  `weight:"q_a_proj"`
	QALayernorm mlx.RMSNorm `weight:"q_a_layernorm"`
	QBProj      mlx.Linear  `weight:"q_b_proj"`

	KVAProjWithMQA mlx.Linear  `weight:"kv_a_proj_with_mqa"`
	KVALayernorm   mlx.RMSNorm `weight:"kv_a_layernorm"`
	KVBProj        mlx.Linear  `weight:"kv_b_proj"`

	// embedQ and unembedOut are derived from KVBProj in AfterLoad: they
	// absorb the KV up-projection into the query and output paths so
	// attention can run directly in the compressed latent space.
	embedQ     MultiLinear
	unembedOut MultiLinear

	OProj mlx.Linear `weight:"o_proj"`
}

// AfterLoad splits the KV up-projection into the query-embedding and
// output-unembedding matrices used by Forward. It re-reads config.json
// for the dimensions and returns the derived arrays so the loader can
// retain them.
func (m *Attention) AfterLoad(root *model.Root) ([]*mlx.Array, error) {
	bts, err := root.ReadFile("config.json")
	if err != nil {
		return nil, err
	}

	var opts struct {
		NumAttentionHeads int `json:"num_attention_heads"`
		QKNopeHeadDim     int `json:"qk_nope_head_dim"`
		KVLoraRank        int `json:"kv_lora_rank"`
	}
	if err := json.Unmarshal(bts, &opts); err != nil {
		return nil, err
	}

	// Split the (heads, nope+value, rank) projection into its query
	// (first QKNopeHeadDim rows) and output (remaining rows) halves.
	// A stop of 0 in mlx.Slice means "to the end of the dimension".
	w := m.KVBProj.Weight.Reshape(opts.NumAttentionHeads, -1, opts.KVLoraRank)
	m.embedQ.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim), mlx.Slice()).Transpose(0, 2, 1))
	m.unembedOut.Weight.Set(w.Slice(mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0), mlx.Slice()))

	return []*mlx.Array{
		&m.embedQ.Weight,
		&m.unembedOut.Weight,
	}, nil
}

// Forward computes MLA attention. The query is factored through the
// low-rank path, split into RoPE and non-RoPE parts; the key/value come
// from a shared compressed latent, of which only the latent plus the
// RoPE key slice is stored in the cache.
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	query := m.QAProj.Forward(hiddenStates)
	query = m.QALayernorm.Forward(query, opts.RMSNormEps)
	query = m.QBProj.Forward(query)

	query = query.Reshape(B, L, opts.NumAttentionHeads, -1)
	query = query.Transpose(0, 2, 1, 3)

	// Split query head dims into non-RoPE and RoPE halves
	// (stop 0 == to end).
	queryNope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.QKNopeHeadDim))
	queryRope := query.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(opts.QKNopeHeadDim, 0))

	compressedKV := m.KVAProjWithMQA.Forward(hiddenStates)

	// Tail of the projection is the single-head RoPE key.
	keyRope := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(opts.KVLoraRank, 0))
	keyRope = keyRope.Reshape(B, L, 1, opts.QKRopeHeadDim)
	keyRope = keyRope.Transpose(0, 2, 1, 3)

	kvCompressed := compressedKV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))

	var offset int
	if cache != nil {
		offset = cache.Offset()
	}

	queryRope = opts.RoPE.Forward(queryRope, offset)
	keyRope = opts.RoPE.Forward(keyRope, offset)

	// Cached "key" = normalized latent concatenated with the RoPE key.
	key := m.KVALayernorm.Forward(kvCompressed, opts.RMSNormEps).
		ExpandDims(1).
		Concatenate(3, keyRope)

	if cache != nil {
		// Values are recovered from the key latent, so a zero-width
		// placeholder is stored in the value slot.
		key, _ = cache.Update(key, mlx.Zeros(mlx.DTypeBFloat16, B, 1, L, 0))
	}

	value := key.Clone().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.KVLoraRank))
	query = m.embedQ.Forward(queryNope).Concatenate(3, queryRope)

	attention := mlx.ScaledDotProductAttention(query, key, value, nil, float32(1.0/math.Sqrt(float64(opts.QKNopeHeadDim+opts.QKRopeHeadDim))))
	attention = m.unembedOut.Forward(attention)
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OProj.Forward(attention)
}
|
||||
|
||||
// MLP abstracts a feed-forward block so layers can use either the dense
// or the sparse (mixture-of-experts) variant.
type MLP interface {
	Forward(*mlx.Array, int, int, Options) *mlx.Array
}

// dense is a standard SwiGLU feed-forward block.
type dense struct {
	GateProj mlx.Linear `weight:"gate_proj"`
	UpProj   mlx.Linear `weight:"up_proj"`
	DownProj mlx.Linear `weight:"down_proj"`
}

// Forward computes down(silu(gate(h)) * up(h)); the B and L arguments are
// unused by the dense variant.
func (m dense) Forward(h *mlx.Array, _, _ int, opts Options) *mlx.Array {
	h = mlx.SILU(m.GateProj.Forward(h)).Multiply(m.UpProj.Forward(h))
	return m.DownProj.Forward(h)
}
|
||||
|
||||
// Gate is the MoE router: a linear scorer plus a per-expert score
// correction bias used only for ranking.
type Gate struct {
	Gate           mlx.Linear `weight:"gate"`
	CorrectionBias mlx.Array  `weight:"gate.e_score_correction_bias"`
}

// expertSelect caches the compiled expert-selection graph.
//
// NOTE(review): the closure is compiled with the opts of the first call
// and reused for all later calls regardless of their opts, and the lazy
// init is unsynchronized — fine for a single-model, single-goroutine
// runner, but confirm before reusing elsewhere.
var expertSelect *mlx.Closure

// ExpertSelect returns a compiled closure mapping (router scores,
// correction bias) to (top-k expert indices, routing weights).
func ExpertSelect(opts Options) *mlx.Closure {
	if expertSelect == nil {
		expertSelect = mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
			scores, correctionBias := inputs[0], inputs[1]

			scores = scores.Sigmoid()
			// The bias influences ranking only; the gathered weights come
			// from the unbiased sigmoid scores.
			original := scores
			scores = scores.Add(correctionBias)

			// Top-k selection via argpartition on negated scores.
			indices := scores.Negative().ArgpartitionAxis(opts.NumExpertsPerTok-1, -1)
			indices = indices.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, opts.NumExpertsPerTok))

			scores = original.TakeAlongAxis(indices, -1)
			if opts.NumExpertsPerTok > 1 && opts.NormTopKProb {
				// Normalize the top-k weights to sum to 1 (epsilon avoids
				// division by zero).
				scores = scores.Divide(scores.SumAxis(-1, true).Add(mlx.FromValue[float32](1e-20)))
			}

			scores = scores.Multiply(mlx.FromValue(opts.RoutedScalingFactor))
			return []*mlx.Array{indices, scores}
		}, false)
	}

	return expertSelect
}

// Forward routes h, returning the selected expert indices and their
// scaled routing weights. Scoring runs in float32 for stability.
func (m Gate) Forward(h *mlx.Array, opts Options) (indices, scores *mlx.Array) {
	outputs := ExpertSelect(opts).Call([]*mlx.Array{
		m.Gate.Forward(h).AsType(mlx.DTypeFloat32),
		&m.CorrectionBias,
	})
	return outputs[0], outputs[1]
}
|
||||
|
||||
// sparse is the mixture-of-experts feed-forward block: a router (Gate),
// per-expert SwiGLU weights fused into stacked tensors for gather-based
// dispatch, and an always-on shared expert.
type sparse struct {
	Gate

	Experts []dense `weight:"experts"`
	// fused holds the per-expert weights stacked along a leading expert
	// dimension (built in AfterLoad) for gathered matmuls.
	fused   struct {
		GateProj mlx.Linear
		UpProj   mlx.Linear
		DownProj mlx.Linear
	}

	SharedExperts dense `weight:"shared_experts"`
}

// AfterLoad stacks each expert's gate/up/down weights into single fused
// tensors and returns them so the loader can retain the new arrays.
func (m *sparse) AfterLoad(*model.Root) ([]*mlx.Array, error) {
	w1 := make([]*mlx.Array, len(m.Experts))
	w2 := make([]*mlx.Array, len(m.Experts))
	w3 := make([]*mlx.Array, len(m.Experts))

	for i := range m.Experts {
		w1[i] = &m.Experts[i].GateProj.Weight
		w2[i] = &m.Experts[i].UpProj.Weight
		w3[i] = &m.Experts[i].DownProj.Weight
	}

	m.fused.GateProj.Weight.Set(w1[0].StackAxis(0, w1[1:]...))
	m.fused.UpProj.Weight.Set(w2[0].StackAxis(0, w2[1:]...))
	m.fused.DownProj.Weight.Set(w3[0].StackAxis(0, w3[1:]...))

	return []*mlx.Array{
		&m.fused.GateProj.Weight,
		&m.fused.UpProj.Weight,
		&m.fused.DownProj.Weight,
	}, nil
}

// Forward routes each token to its top-k experts, evaluates the fused
// expert SwiGLU via gathered matmuls, combines the expert outputs with
// the routing weights, and adds the shared-expert output.
func (m sparse) Forward(h *mlx.Array, B, L int, opts Options) *mlx.Array {
	indices, scores := m.Gate.Forward(h, opts)
	scores = scores.ExpandDims(-1)

	flat := h.ExpandDims(-2).ExpandDims(-2).Reshape(-1, 1, 1, opts.HiddenSize)
	indices = indices.Reshape(-1, opts.NumExpertsPerTok)

	// For large batches, sorting tokens by expert improves gather
	// locality; inverseOrder restores the original token order after.
	sort := B*L >= 64
	var inverseOrder *mlx.Array
	if sort {
		indicesAll := indices.Flatten(0, len(indices.Dims())-1)
		order := indicesAll.ArgsortAxis(0)
		inverseOrder = order.ArgsortAxis(0)
		flat = flat.Squeeze(1).TakeAxis(order.FloorDivide(mlx.FromValue(opts.NumExpertsPerTok)), 0).ExpandDims(1)
		indices = indicesAll.TakeAxis(order, 0).Reshape(B*L*opts.NumExpertsPerTok, 1)
	}

	experts := mlx.SILU(m.fused.GateProj.Gather(flat, nil, indices, sort)).
		Multiply(m.fused.UpProj.Gather(flat, nil, indices, sort))
	experts = m.fused.DownProj.Gather(experts, nil, indices, sort)

	if sort {
		experts = experts.Squeeze(2).Squeeze(1).TakeAxis(inverseOrder, 0)
		experts = experts.Reshape(-1, opts.NumExpertsPerTok, opts.HiddenSize)
	} else {
		experts = experts.Squeeze(2)
	}

	// Weighted sum over the selected experts, plus the shared expert.
	experts = experts.Reshape(B, L, opts.NumExpertsPerTok, opts.HiddenSize)
	experts = experts.Multiply(scores).SumAxis(-2, false).AsType(experts.DType())
	experts = experts.Add(m.SharedExperts.Forward(h, B, L, opts))
	return experts.Reshape(B, L, -1)
}
|
||||
|
||||
// init registers the GLM-4 MoE Lite constructor. The constructor decodes
// config.json, derives the RoPE settings, and builds the layer list with
// the first FirstKDenseReplace layers dense and the remainder sparse.
func init() {
	base.Register("Glm4MoeLiteForCausalLM", func(root *model.Root) (base.Model, error) {
		bts, err := root.ReadFile("config.json")
		if err != nil {
			return nil, err
		}

		var opts Options
		if err := json.Unmarshal(bts, &opts); err != nil {
			return nil, err
		}

		// RoPE applies only to the rotary slice of each head.
		opts.RoPE = mlx.RoPE{
			Dims:        opts.QKRopeHeadDim,
			Traditional: true,
			Base:        opts.RopeTheta,
			Scale:       1,
		}

		layers := make([]Layer, opts.NumHiddenLayers)
		for i := range layers {
			if i < opts.FirstKDenseReplace {
				layers[i].MLP = &dense{}
			} else {
				layers[i].MLP = &sparse{Experts: make([]dense, opts.NumRoutedExperts)}
			}
		}

		return &Model{
			Layers:  layers,
			Options: opts,
		}, nil
	})
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package llama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// Options holds the llama-family hyperparameters decoded from config.json.
type Options struct {
	HiddenAct         string  `json:"hidden_act"`
	HiddenSize        int     `json:"hidden_size"`
	IntermediateSize  int     `json:"intermediate_size"`
	NumAttentionHeads int     `json:"num_attention_heads"`
	NumHiddenLayers   int     `json:"num_hidden_layers"`
	NumKeyValueHeads  int     `json:"num_key_value_heads"`
	RMSNormEps        float32 `json:"rms_norm_eps"`

	// RoPE settings; unmarshaled separately from the same config bytes
	// in the package init.
	mlx.RoPE
}
|
||||
|
||||
// Model is a llama-family causal language model with an untied LM head.
type Model struct {
	EmbedTokens mlx.Embedding `weight:"model.embed_tokens"`
	Layers      []Layer       `weight:"model.layers"`
	Norm        mlx.RMSNorm   `weight:"model.norm"`
	Output      mlx.Linear    `weight:"lm_head"`

	Options
}

// NumLayers returns the number of decoder layers.
func (m Model) NumLayers() int {
	return len(m.Layers)
}

// Forward embeds the (B, L) token ids, runs every decoder layer with its
// cache, applies the final norm, and — unlike the other models in this
// tree — projects through the LM head here, returning logits directly.
func (m Model) Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array {
	slog.Debug("Model.forward", "input shape", inputs.Dims(), "m.EmbedTokens", m.EmbedTokens.Weight.Dims())
	B, L := inputs.Dim(0), inputs.Dim(1)
	hiddenStates := m.EmbedTokens.Forward(inputs)
	for i, layer := range m.Layers {
		hiddenStates = layer.Forward(hiddenStates, caches[i], B, L, m.Options)
	}
	hiddenStates = m.Norm.Forward(hiddenStates, m.RMSNormEps)
	hiddenStates = m.Output.Forward(hiddenStates)
	slog.Debug("Model.forward", "output shape", hiddenStates.Dims(), "m.Output", m.Output.Weight.Dims())
	return hiddenStates
}
|
||||
|
||||
// Layer is one llama transformer block (pre-norm attention and MLP, each
// with a residual connection).
type Layer struct {
	AttentionNorm mlx.RMSNorm `weight:"input_layernorm"`
	Attention     Attention   `weight:"self_attn"`
	MLPNorm       mlx.RMSNorm `weight:"post_attention_layernorm"`
	MLP           MLP         `weight:"mlp"`
}

// Forward runs norm→attention with a residual add, then norm→MLP with a
// second residual add.
func (m Layer) Forward(hiddenStates *mlx.Array, c cache.Cache, B, L int, opts Options) *mlx.Array {
	residual := hiddenStates
	hiddenStates = m.AttentionNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.Attention.Forward(hiddenStates, c, B, L, opts)
	hiddenStates = hiddenStates.Add(residual)

	residual = hiddenStates
	hiddenStates = m.MLPNorm.Forward(hiddenStates, opts.RMSNormEps)
	hiddenStates = m.MLP.Forward(hiddenStates)
	hiddenStates = hiddenStates.Add(residual)
	return hiddenStates
}
|
||||
|
||||
// Attention is standard grouped-query attention with RoPE.
type Attention struct {
	QueryProj  mlx.Linear `weight:"q_proj"`
	KeyProj    mlx.Linear `weight:"k_proj"`
	ValueProj  mlx.Linear `weight:"v_proj"`
	OutputProj mlx.Linear `weight:"o_proj"`
}

// Forward projects to Q/K/V reshaped as (B, heads, L, headDim), applies
// RoPE offset by the cache position, updates the cache, and runs scaled
// dot-product attention before the output projection.
func (m Attention) Forward(hiddenStates *mlx.Array, cache cache.Cache, B, L int, opts Options) *mlx.Array {
	query := m.QueryProj.Forward(hiddenStates)
	query = query.Reshape(B, L, opts.NumAttentionHeads, -1).Transpose(0, 2, 1, 3)

	key := m.KeyProj.Forward(hiddenStates)
	key = key.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)

	value := m.ValueProj.Forward(hiddenStates)
	value = value.Reshape(B, L, opts.NumKeyValueHeads, -1).Transpose(0, 2, 1, 3)

	query = opts.RoPE.Forward(query, cache.Offset())
	key = opts.RoPE.Forward(key, cache.Offset())
	key, value = cache.Update(key, value)

	// Scale by 1/sqrt(headDim), taken from the key's last dimension.
	attention := mlx.ScaledDotProductAttention(query, key, value, nil, 1.0/float32(math.Sqrt(float64(key.Dim(-1)))))
	attention = attention.Transpose(0, 2, 1, 3).Reshape(B, L, -1)
	return m.OutputProj.Forward(attention)
}
|
||||
|
||||
// MLP is the gated feed-forward block (SwiGLU-style: gate/up/down
// projections). The `weight` tags map each field to its checkpoint tensor.
type MLP struct {
	Gate mlx.Linear `weight:"gate_proj"`
	Up   mlx.Linear `weight:"up_proj"`
	Down mlx.Linear `weight:"down_proj"`
}
|
||||
|
||||
func (m MLP) Forward(h *mlx.Array) *mlx.Array {
|
||||
return m.Down.Forward(mlx.SILU(m.Gate.Forward(h)).Multiply(m.Up.Forward(h)))
|
||||
}
|
||||
|
||||
func init() {
|
||||
base.Register("MistralForCausalLM", func(root *model.Root) (base.Model, error) {
|
||||
bts, err := root.ReadFile("config.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts Options
|
||||
// TODO: implement json.Unmarshal for Options
|
||||
if err := json.Unmarshal(bts, &opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &opts.RoPE); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
Layers: make([]Layer, opts.NumHiddenLayers),
|
||||
Options: opts,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/mlxrunner/model/gemma/3"
|
||||
_ "github.com/ollama/ollama/x/mlxrunner/model/glm/4/moe/lite"
|
||||
_ "github.com/ollama/ollama/x/mlxrunner/model/llama"
|
||||
)
|
||||
@@ -1,138 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// TextGenerationPipeline serves one completion request end to end: prompt
// processing in chunks, then a pipelined token-generation loop that streams
// Responses on request.Responses. The final Response carries token counts,
// timings, and the done reason. Returns an error only when the model cannot
// serve text generation or the prompt fails to tokenize.
func (r *Runner) TextGenerationPipeline(request Request) error {
	model, ok := r.Model.(base.TextGeneration)
	if !ok {
		return errors.New("model does not support causal language modeling")
	}

	inputs, err := r.Tokenizer.Encode(request.Prompt, true)
	if err != nil {
		return err
	}

	// Reuse a previously computed KV cache when a prefix of the prompt was
	// seen before; tokens holds the suffix that still needs processing.
	caches, tokens := r.FindNearestCache(inputs)
	if len(caches) == 0 {
		if cacher, ok := model.(base.Cacher); ok {
			caches = cacher.Cache()
		} else {
			caches = make([]cache.Cache, model.NumLayers())
			for i := range caches {
				caches[i] = cache.NewKVCache()
			}
		}
	}

	// Prompt processing: feed the prompt in chunks of up to 2048 tokens,
	// leaving the last token for the first generation step below.
	total, processed := len(tokens), 0
	slog.Info("Prompt processing progress", "processed", processed, "total", total)
	for total-processed > 1 {
		n := min(2<<10, total-processed-1)
		temp := model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		// NOTE(review): defer inside a loop only runs at function return, so
		// every chunk's forward output stays allocated for the whole request;
		// consider freeing per iteration after the Eval below.
		defer mlx.Free(temp)
		// Force evaluation of every cache's updated key/value state.
		mlx.Eval(func() []*mlx.Array {
			s := make([]*mlx.Array, 2*len(caches))
			for i, c := range caches {
				s[2*i], s[2*i+1] = c.State()
			}
			return s
		}()...)
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)
		mlx.ClearCache()
	}

	// step runs one forward pass over token(s) and returns the sampled token
	// plus the normalized log-probabilities of the last position.
	step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
		logits := model.Unembed(model.Forward(token.ExpandDims(0), caches))
		// Keep only the logits for the final sequence position.
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		// TODO: additional logit processing (logit bias, repetition penalty, etc.)

		// Normalize to log-probabilities via logsumexp.
		logprobs := logits.Subtract(logits.Logsumexp(true))
		return request.Sample(logprobs), logprobs
	}

	// First step consumes the remaining prompt tokens.
	sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
	mlx.AsyncEval(sample, logprobs)

	// buffer partial, multibyte unicode
	var b bytes.Buffer

	now := time.Now()
	// DoneReason 1 = token limit reached; overwritten with 0 on EOS below.
	final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
	outputs := make([]int32, 0, request.Options.MaxTokens)
	for i := range request.Options.MaxTokens {
		// Kick off the next step before reading the current sample, so the
		// device computes ahead while the host decodes/streams.
		nextSample, nextLogprobs := step(sample)
		mlx.AsyncEval(nextSample, nextLogprobs)

		if i == 0 {
			// First materialized sample marks the end of prompt processing.
			slog.Info("Prompt processing progress", "processed", total, "total", total)
			mlx.Eval(sample)
			final.PromptTokensDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Int())
		outputs = append(outputs, output)

		if r.Tokenizer.Is(output, tokenizer.SpecialEOS) {
			// NOTE(review): nextSample/nextLogprobs from the lookahead step
			// are not freed on this path.
			final.Token = int(output)
			final.DoneReason = 0
			final.CompletionTokens = i
			break
		}

		request.Responses <- Response{
			Text:  r.Decode(output, &b),
			Token: int(output),
		}

		mlx.Free(sample, logprobs)
		// Periodically release cached device allocations.
		if i%256 == 0 {
			mlx.ClearCache()
		}

		sample, logprobs = nextSample, nextLogprobs
	}

	mlx.Free(sample, logprobs)
	final.CompletionTokensDuration = time.Since(now)
	request.Responses <- final
	// Store the full token sequence + caches for prefix reuse by later requests.
	r.InsertCache(append(inputs, outputs...), caches)
	return nil
}
|
||||
|
||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
token, err := r.Tokenizer.Decode([]int32{sample})
|
||||
if err != nil {
|
||||
slog.Error("Failed to decode tokens", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if _, err := b.WriteString(token); err != nil {
|
||||
slog.Error("Failed to write token to buffer", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if text := b.String(); utf8.ValidString(text) {
|
||||
b.Reset()
|
||||
return text
|
||||
} else if b.Len() >= utf8.UTFMax {
|
||||
b.Reset()
|
||||
return text
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
_ "github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
// Request is an in-flight generation request: the parsed HTTP payload, the
// channel its Responses are streamed on, and the pipeline that will serve it.
type Request struct {
	TextCompletionsRequest
	Responses chan Response
	Pipeline  func(Request) error

	sample.Sampler
	caches []cache.Cache
}
|
||||
|
||||
// TextCompletionsRequest is the JSON body of a POST /v1/completions call.
type TextCompletionsRequest struct {
	Prompt  string `json:"prompt"`
	Options struct {
		Temperature float32 `json:"temperature"`
		TopP        float32 `json:"top_p"`
		MinP        float32 `json:"min_p"`
		TopK        int     `json:"top_k"`
		MaxTokens   int     `json:"max_tokens"`

		// Deprecated: use MaxTokens instead.
		NumPredict int `json:"num_predict"`
	} `json:"options"`
}
|
||||
|
||||
// Response is one streamed chunk of a completion. Intermediate responses
// carry Text/Token; the final response carries counts, durations, and the
// done reason.
type Response struct {
	Text     string    `json:"content,omitempty"`
	Token    int       `json:"token,omitempty"`
	Logprobs []float32 `json:"logprobs,omitempty"`
	Done     bool      `json:"done,omitempty"`
	// 0 = natural stop (EOS), 1 = token limit reached.
	DoneReason int `json:"done_reason,omitempty"`

	PromptTokens             int           `json:"prompt_eval_count,omitempty"`
	PromptTokensDuration     time.Duration `json:"prompt_eval_duration,omitempty"`
	CompletionTokens         int           `json:"eval_count,omitempty"`
	CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
	TotalTokens              int           `json:"total_tokens,omitempty"`
}
|
||||
|
||||
// Runner owns the loaded model, its tokenizer, the serialized request queue,
// and the prefix KV-cache entries shared across requests.
type Runner struct {
	Model        base.Model
	Tokenizer    tokenizer.Tokenizer
	Requests     chan Request
	CacheEntries map[int32]*CacheEntry
}
|
||||
|
||||
func (r *Runner) Load(name model.Name) (err error) {
|
||||
root, err := model.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
r.Model, err = base.New(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Tokenizer, err = tokenizer.New(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
weights, quantizations, afterLoadFuncs := base.Walk(r.Model)
|
||||
return mlx.LoadAll(root, weights, quantizations, afterLoadFuncs)
|
||||
}
|
||||
|
||||
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
slog.Info("Starting HTTP server", "host", host, "port", port)
|
||||
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// Sampler transforms an array of log-probabilities: filters (TopP, MinP,
// TopK) return modified logprobs, while terminal samplers (greedy,
// Temperature) return a drawn token index.
type Sampler interface {
	Sample(*mlx.Array) *mlx.Array
}
|
||||
|
||||
func New(temp, top_p, min_p float32, top_k int) Sampler {
|
||||
if temp == 0 {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
var samplers []Sampler
|
||||
if top_p > 0 && top_p < 1 {
|
||||
samplers = append(samplers, TopP(top_p))
|
||||
}
|
||||
|
||||
if min_p != 0 {
|
||||
samplers = append(samplers, MinP(min_p))
|
||||
}
|
||||
|
||||
if top_k > 0 {
|
||||
samplers = append(samplers, TopK(top_k))
|
||||
}
|
||||
|
||||
samplers = append(samplers, Temperature(temp))
|
||||
return chain(samplers)
|
||||
}
|
||||
|
||||
// greedy always picks the highest-probability token (argmax decoding).
type greedy struct{}

// Sample returns the index of the maximum value along the last axis.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return logits.Argmax(-1, false)
}
|
||||
|
||||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
||||
for _, sampler := range c {
|
||||
logits = sampler.Sample(logits)
|
||||
}
|
||||
return logits
|
||||
}
|
||||
|
||||
type Temperature float32
|
||||
|
||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1)
|
||||
}
|
||||
|
||||
// TopP is intended to implement nucleus (top-p) filtering of logprobs.
// Currently a pass-through stub.
type TopP float32

// Sample returns logprobs unchanged until top-p filtering is implemented.
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}
|
||||
|
||||
// MinP is intended to implement min-p filtering of logprobs.
// Currently a pass-through stub.
type MinP float32

// Sample returns logprobs unchanged until min-p filtering is implemented.
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
	// TODO: implement
	return logprobs
}
|
||||
|
||||
// TopK keeps only the k largest log-probabilities, masking the rest to -Inf.
type TopK int

// Sample masks all but the top k entries along the last axis.
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
	// Negate so an ascending argpartition places the k largest entries first,
	// then take the indices past position k as the set to mask out.
	// NOTE(review): assumes mlx.Slice(k, 0) means "from index k to the end" —
	// confirm against the mlx binding's slice semantics.
	mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
	return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
|
||||
@@ -1,180 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
// Execute is the mlxrunner entry point: it parses flags, loads the model,
// registers the HTTP API (status, model management, completions, tokenize,
// plus legacy-path redirects), and blocks serving requests.
func Execute(args []string) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))

	var (
		name model.Name
		port int
	)

	flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
	flagSet.Var(&name, "model", "Model name")
	flagSet.IntVar(&port, "port", 0, "Port to listen on")
	// Accepted for CLI compatibility only; log level comes from envconfig.
	_ = flagSet.Bool("verbose", false, "Enable debug logging")
	flagSet.Parse(args)

	runner := Runner{
		Requests:     make(chan Request),
		CacheEntries: make(map[int32]*CacheEntry),
	}

	if err := runner.Load(name); err != nil {
		return err
	}

	mux := http.NewServeMux()
	// Health/readiness probe: always reports loaded (progress 100).
	mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewEncoder(w).Encode(map[string]any{
			"status":   0,
			"progress": 100,
		}); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})

	// Model management: GET/POST acknowledge; DELETE is not yet implemented.
	mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case "POST":
			fallthrough
		case "GET":
			if err := json.NewEncoder(w).Encode(map[string]any{
				"Success": true,
			}); err != nil {
				slog.Error("Failed to encode response", "error", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
		case "DELETE":
			// TODO: cleanup model and cache
		}
	})

	// Completions: decode the request, queue it, then stream JSONL responses
	// until the runner closes the channel.
	mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
		request := Request{Responses: make(chan Response)}

		if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
			slog.Error("Failed to decode request", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}

		// Honor the deprecated num_predict alias; default to 16384 tokens.
		request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
		if request.Options.MaxTokens < 1 {
			request.Options.MaxTokens = 16 << 10
		}

		request.Pipeline = runner.TextGenerationPipeline
		request.Sampler = sample.New(
			request.Options.Temperature,
			request.Options.TopP,
			request.Options.MinP,
			request.Options.TopK,
		)

		// NOTE(review): unbuffered send — the handler blocks here until the
		// single runner goroutine picks the request up; there is no queue
		// bound or client-cancellation check.
		runner.Requests <- request

		w.Header().Set("Content-Type", "application/jsonl")
		w.WriteHeader(http.StatusOK)
		enc := json.NewEncoder(w)
		for response := range request.Responses {
			if err := enc.Encode(response); err != nil {
				slog.Error("Failed to encode response", "error", err)
				return
			}

			// Flush each line so clients see tokens as they are generated.
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	})

	// Tokenize: body is raw text; response is the JSON array of token ids.
	mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
		var b bytes.Buffer
		if _, err := io.Copy(&b, r.Body); err != nil {
			slog.Error("Failed to read request body", "error", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}

		tokens, err := runner.Tokenizer.Encode(b.String(), true)
		if err != nil {
			slog.Error("Failed to tokenize text", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		if err := json.NewEncoder(w).Encode(tokens); err != nil {
			slog.Error("Failed to encode response", "error", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	})

	// Redirect legacy runner paths to their /v1 equivalents.
	for source, target := range map[string]string{
		"GET /health":      "/v1/status",
		"POST /load":       "/v1/models",
		"POST /completion": "/v1/completions",
	} {
		mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
	}

	// Wrap the mux with request logging: record the status code, log at a
	// level derived from it, and skip logging for redirects (3xx).
	return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
		t := time.Now()
		mux.ServeHTTP(recorder, r)

		var level slog.Level
		switch {
		case recorder.code >= 500:
			level = slog.LevelError
		case recorder.code >= 400:
			level = slog.LevelWarn
		case recorder.code >= 300:
			return
		}

		slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
	}))
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *statusRecorder) WriteHeader(code int) {
|
||||
w.code = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (w *statusRecorder) Status() string {
|
||||
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
|
||||
}
|
||||
|
||||
func (w *statusRecorder) Flush() {
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user