mirror of
https://github.com/ollama/ollama.git
synced 2026-02-23 10:45:08 -05:00
Compare commits
2 Commits
jessegross
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
365a3657ad | ||
|
|
71c1d8d0a9 |
@@ -1 +1 @@
|
||||
v0.4.1
|
||||
v0.5.0
|
||||
|
||||
22
auth/auth.go
22
auth/auth.go
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
// signature is <pubkey>:<signature>
|
||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||
}
|
||||
|
||||
// SignRequest adds a nonce query parameter and an Authorization header with
|
||||
// an Ed25519 signature to req.
|
||||
func SignRequest(ctx context.Context, req *http.Request) error {
|
||||
nonce, err := NewNonce(rand.Reader, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Set("nonce", nonce)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
|
||||
signature, err := Sign(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
28
cmd/cmd.go
28
cmd/cmd.go
@@ -1900,6 +1900,21 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
|
||||
if version.HasCachedUpdate() {
|
||||
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
|
||||
_ = version.ClearCachedUpdate()
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if available, err := version.CheckForUpdate(ctx); err == nil && available {
|
||||
_ = version.CacheAvailableUpdate()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
@@ -2317,6 +2332,18 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
updateCmd := &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update Ollama to the latest version",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
_ = version.ClearCachedUpdate()
|
||||
return version.DoUpdate(force)
|
||||
},
|
||||
}
|
||||
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
|
||||
|
||||
rootCmd.AddCommand(
|
||||
serveCmd,
|
||||
createCmd,
|
||||
@@ -2334,6 +2361,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
updateCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -33,10 +32,6 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
1
go.mod
1
go.mod
@@ -26,7 +26,6 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -122,6 +122,7 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -149,9 +150,8 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -497,17 +496,6 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1239,102 +1238,3 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
# Wrap script in main function so that a truncated partial download doesn't end
|
||||
# up executing half a script.
|
||||
main() {
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
@@ -450,6 +446,3 @@ fi
|
||||
|
||||
status "NVIDIA GPU ready."
|
||||
install_success
|
||||
}
|
||||
|
||||
main
|
||||
|
||||
190
version/update.go
Normal file
190
version/update.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
var updateCheckURLBase = "https://ollama.com"
|
||||
|
||||
// CheckForUpdate calls the ollama.com update API and reports whether a
|
||||
// newer version is available.
|
||||
func CheckForUpdate(ctx context.Context) (bool, error) {
|
||||
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse update URL: %w", err)
|
||||
}
|
||||
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", Version)
|
||||
requestURL.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
_ = auth.SignRequest(ctx, req)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("update check request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func cacheFilePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "update"), nil
|
||||
}
|
||||
|
||||
// CacheAvailableUpdate creates the update marker file.
|
||||
func CacheAvailableUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// HasCachedUpdate reports whether a non-stale update marker exists.
|
||||
func HasCachedUpdate() bool {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(fi.ModTime()) <= 24*time.Hour
|
||||
}
|
||||
|
||||
// ClearCachedUpdate removes the update marker file.
|
||||
func ClearCachedUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Remove(path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func IsOfficialInstall() bool {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
|
||||
case "darwin":
|
||||
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
|
||||
default:
|
||||
dir := filepath.Dir(exe)
|
||||
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
|
||||
}
|
||||
}
|
||||
|
||||
// DoUpdate downloads and runs the platform-appropriate install script.
|
||||
func DoUpdate(force bool) error {
|
||||
if !force && !IsOfficialInstall() {
|
||||
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
|
||||
}
|
||||
|
||||
var scriptURL, tmpPattern, shell string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
scriptURL = "https://ollama.com/install.ps1"
|
||||
tmpPattern = "ollama-install-*.ps1"
|
||||
shell = "powershell"
|
||||
default:
|
||||
scriptURL = "https://ollama.com/install.sh"
|
||||
tmpPattern = "ollama-install-*.sh"
|
||||
shell = "sh"
|
||||
}
|
||||
|
||||
resp, err := http.Get(scriptURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download install script: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download install script: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", tmpPattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("write install script: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cmd := exec.Command(shell, tmpFile.Name())
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// IsLocalHost reports whether the configured Ollama host points to the
|
||||
// local machine.
|
||||
func IsLocalHost(host *url.URL) bool {
|
||||
hostname := host.Hostname()
|
||||
switch hostname {
|
||||
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
146
version/update_test.go
Normal file
146
version/update_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
} else {
|
||||
t.Setenv("HOME", dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
t.Run("update available", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
|
||||
t.Error("missing expected query parameters")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Fatal("expected update to be available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("up to date", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if available {
|
||||
t.Fatal("expected no update available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("network error", func(t *testing.T) {
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = "http://localhost:1"
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := CheckForUpdate(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheRoundTrip(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
if !HasCachedUpdate() {
|
||||
t.Fatal("expected cached update to be present")
|
||||
}
|
||||
|
||||
if err := ClearCachedUpdate(); err != nil {
|
||||
t.Fatalf("cache clear: %v", err)
|
||||
}
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCachedUpdateStale(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
// Backdate the file to make it stale
|
||||
path := filepath.Join(tmp, ".ollama", "update")
|
||||
staleTime := time.Now().Add(-25 * time.Hour)
|
||||
os.Chtimes(path, staleTime, staleTime)
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update for stale file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
host string
|
||||
local bool
|
||||
}{
|
||||
{"http://127.0.0.1:11434", true},
|
||||
{"http://localhost:11434", true},
|
||||
{"http://[::1]:11434", true},
|
||||
{"http://0.0.0.0:11434", true},
|
||||
{"http://remote.example.com:11434", false},
|
||||
{"http://192.168.1.100:11434", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.host, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.host)
|
||||
if err != nil {
|
||||
t.Fatalf("parse URL: %v", err)
|
||||
}
|
||||
if got := IsLocalHost(u); got != tt.local {
|
||||
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Private headers contain C++ implementation helpers and are not part of
|
||||
// the C API surface; parsing them can produce invalid wrapper signatures.
|
||||
if d.IsDir() && d.Name() == "private" {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
||||
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
||||
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
||||
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
||||
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
||||
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
||||
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_enable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
||||
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
||||
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
||||
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
|
||||
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
||||
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
|
||||
if (mlx_array_new_data_managed_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
|
||||
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
|
||||
if (mlx_array_set_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
||||
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
|
||||
if (mlx_cuda_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
|
||||
if (mlx_device_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
||||
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
|
||||
if (mlx_device_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
|
||||
if (mlx_device_count_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
|
||||
if (mlx_device_info_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
|
||||
if (mlx_device_info_get_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
|
||||
if (mlx_device_info_free_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
|
||||
if (mlx_device_info_has_key_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
|
||||
if (mlx_device_info_is_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
|
||||
if (mlx_device_info_get_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
|
||||
if (mlx_device_info_get_size_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
|
||||
if (mlx_device_info_get_keys_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
|
||||
if (mlx_distributed_all_gather_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
||||
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
|
||||
if (mlx_metal_device_info_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
|
||||
if (mlx_metal_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
||||
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
|
||||
return mlx_array_new_data_ptr(data, shape, dim, dtype);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
|
||||
}
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src) {
|
||||
return mlx_array_set_ptr(arr, src);
|
||||
}
|
||||
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
|
||||
return mlx_array_item_float64_ptr(res, arr);
|
||||
}
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
||||
return mlx_array_item_complex64_ptr(res, arr);
|
||||
}
|
||||
|
||||
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
|
||||
return mlx_array_data_float64_ptr(arr);
|
||||
}
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
|
||||
return mlx_array_data_complex64_ptr(arr);
|
||||
}
|
||||
|
||||
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
|
||||
return mlx_set_compile_mode_ptr(mode);
|
||||
}
|
||||
|
||||
int mlx_cuda_is_available(bool* res) {
|
||||
return mlx_cuda_is_available_ptr(res);
|
||||
}
|
||||
|
||||
mlx_device mlx_device_new(void) {
|
||||
return mlx_device_new_ptr();
|
||||
}
|
||||
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
|
||||
return mlx_set_default_device_ptr(dev);
|
||||
}
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
||||
return mlx_device_is_available_ptr(avail, dev);
|
||||
}
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type) {
|
||||
return mlx_device_count_ptr(count, type);
|
||||
}
|
||||
|
||||
mlx_device_info mlx_device_info_new(void) {
|
||||
return mlx_device_info_new_ptr();
|
||||
}
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
||||
return mlx_device_info_get_ptr(info, dev);
|
||||
}
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info) {
|
||||
return mlx_device_info_free_ptr(info);
|
||||
}
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_has_key_ptr(exists, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_is_string_ptr(is_string, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_string_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_size_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
||||
return mlx_device_info_get_keys_ptr(keys, info);
|
||||
}
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
||||
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
||||
}
|
||||
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
|
||||
return mlx_set_wired_limit_ptr(res, limit);
|
||||
}
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void) {
|
||||
return mlx_metal_device_info_ptr();
|
||||
}
|
||||
|
||||
int mlx_metal_is_available(bool* res) {
|
||||
return mlx_metal_is_available_ptr(res);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@
|
||||
#undef mlx_array_new_double
|
||||
#undef mlx_array_new_complex
|
||||
#undef mlx_array_new_data
|
||||
#undef mlx_array_new_data_managed
|
||||
#undef mlx_array_new_data_managed_payload
|
||||
#undef mlx_array_set
|
||||
#undef mlx_array_set_bool
|
||||
#undef mlx_array_set_int
|
||||
@@ -121,6 +123,7 @@
|
||||
#undef mlx_disable_compile
|
||||
#undef mlx_enable_compile
|
||||
#undef mlx_set_compile_mode
|
||||
#undef mlx_cuda_is_available
|
||||
#undef mlx_device_new
|
||||
#undef mlx_device_new_type
|
||||
#undef mlx_device_free
|
||||
@@ -131,6 +134,16 @@
|
||||
#undef mlx_device_get_type
|
||||
#undef mlx_get_default_device
|
||||
#undef mlx_set_default_device
|
||||
#undef mlx_device_is_available
|
||||
#undef mlx_device_count
|
||||
#undef mlx_device_info_new
|
||||
#undef mlx_device_info_get
|
||||
#undef mlx_device_info_free
|
||||
#undef mlx_device_info_has_key
|
||||
#undef mlx_device_info_is_string
|
||||
#undef mlx_device_info_get_string
|
||||
#undef mlx_device_info_get_size
|
||||
#undef mlx_device_info_get_keys
|
||||
#undef mlx_distributed_all_gather
|
||||
#undef mlx_distributed_all_max
|
||||
#undef mlx_distributed_all_min
|
||||
@@ -261,7 +274,6 @@
|
||||
#undef mlx_set_cache_limit
|
||||
#undef mlx_set_memory_limit
|
||||
#undef mlx_set_wired_limit
|
||||
#undef mlx_metal_device_info
|
||||
#undef mlx_metal_is_available
|
||||
#undef mlx_metal_start_capture
|
||||
#undef mlx_metal_stop_capture
|
||||
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_double_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
|
||||
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
|
||||
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
|
||||
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
|
||||
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
|
||||
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
|
||||
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
|
||||
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
|
||||
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
|
||||
#endif
|
||||
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
|
||||
extern int (*mlx_disable_compile_ptr)(void);
|
||||
extern int (*mlx_enable_compile_ptr)(void);
|
||||
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
|
||||
extern int (*mlx_cuda_is_available_ptr)(bool* res);
|
||||
extern mlx_device (*mlx_device_new_ptr)(void);
|
||||
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
|
||||
extern int (*mlx_device_free_ptr)(mlx_device dev);
|
||||
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
|
||||
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
|
||||
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
|
||||
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
|
||||
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
|
||||
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
|
||||
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
|
||||
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
|
||||
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
|
||||
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
|
||||
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
|
||||
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
|
||||
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
|
||||
extern int (*mlx_metal_is_available_ptr)(bool* res);
|
||||
extern int (*mlx_metal_start_capture_ptr)(const char* path);
|
||||
extern int (*mlx_metal_stop_capture_ptr)(void);
|
||||
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
|
||||
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
|
||||
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
|
||||
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
|
||||
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
|
||||
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void);
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
|
||||
@@ -3,65 +3,94 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// CacheEntry stores a single sequence
|
||||
type CacheEntry struct {
|
||||
Tokens []int32
|
||||
Caches []cache.Cache
|
||||
Caches []cache.Cache
|
||||
Count int
|
||||
Entries map[int32]*CacheEntry
|
||||
}
|
||||
|
||||
// FindNearestCache finds the longest common prefix between tokens and the cached sequence
|
||||
func (r *Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||
if r.cache == nil {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
// Find longest common prefix
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(r.cache.Tokens) && tokens[prefix] == r.cache.Tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
|
||||
switch {
|
||||
case prefix == 0:
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Free()
|
||||
func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
|
||||
current := &CacheEntry{Entries: s.CacheEntries}
|
||||
index, cacheIndex := 0, -1
|
||||
for _, token := range tokens {
|
||||
if _, ok := current.Entries[token]; !ok {
|
||||
break
|
||||
}
|
||||
r.cache = nil
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
case prefix < len(r.cache.Tokens):
|
||||
trim := len(r.cache.Tokens) - prefix
|
||||
for _, c := range r.cache.Caches {
|
||||
c.Trim(trim)
|
||||
|
||||
current = current.Entries[token]
|
||||
if len(current.Caches) > 0 {
|
||||
cacheIndex = index
|
||||
}
|
||||
r.cache.Tokens = r.cache.Tokens[:prefix]
|
||||
|
||||
index += 1
|
||||
}
|
||||
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
return r.cache.Caches, tokens[prefix:]
|
||||
if cacheIndex == len(tokens)-1 {
|
||||
slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
|
||||
return current.Caches, []int32{}
|
||||
} else if cacheIndex > 1 {
|
||||
slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
|
||||
return current.Caches, tokens[cacheIndex+1:]
|
||||
} else if index > 0 && cacheIndex < 0 {
|
||||
type stackItem struct {
|
||||
entry *CacheEntry
|
||||
tokens []int32
|
||||
}
|
||||
|
||||
var best, item stackItem
|
||||
stack := []stackItem{{entry: current, tokens: []int32{}}}
|
||||
for len(stack) > 0 {
|
||||
item, stack = stack[len(stack)-1], stack[:len(stack)-1]
|
||||
if len(item.entry.Caches) > 0 {
|
||||
if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
|
||||
best = item
|
||||
}
|
||||
} else {
|
||||
for token, entry := range item.entry.Entries {
|
||||
stack = append(stack, stackItem{
|
||||
entry: entry,
|
||||
tokens: append(item.tokens, token),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prefix := min(len(tokens)-1, index)
|
||||
caches := make([]cache.Cache, len(best.entry.Caches))
|
||||
trim := len(best.tokens)+1
|
||||
for i := range caches {
|
||||
caches[i] = best.entry.Caches[i].Clone()
|
||||
caches[i].Trim(trim)
|
||||
}
|
||||
|
||||
slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
|
||||
return caches, tokens[prefix:]
|
||||
}
|
||||
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
return nil, tokens
|
||||
}
|
||||
|
||||
func (r *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
r.cache = &CacheEntry{
|
||||
Tokens: tokens,
|
||||
Caches: caches,
|
||||
}
|
||||
}
|
||||
func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
|
||||
current := &CacheEntry{Entries: s.CacheEntries}
|
||||
for _, token := range tokens {
|
||||
if _, ok := current.Entries[token]; !ok {
|
||||
current.Entries[token] = &CacheEntry{
|
||||
Entries: make(map[int32]*CacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheEntry) LogCache() {
|
||||
var totalBytes int
|
||||
for _, kv := range c.Caches {
|
||||
k, v := kv.State()
|
||||
totalBytes += k.NumBytes() + v.NumBytes()
|
||||
current = current.Entries[token]
|
||||
}
|
||||
|
||||
if len(current.Caches) > 0 {
|
||||
current.Count += 1
|
||||
} else {
|
||||
current.Caches = caches
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", c.Caches[0].Offset(), mlx.PrettyBytes(totalBytes)))
|
||||
}
|
||||
|
||||
15
x/mlxrunner/cache/cache.go
vendored
15
x/mlxrunner/cache/cache.go
vendored
@@ -13,7 +13,6 @@ type Cache interface {
|
||||
State() (keys, values *mlx.Array)
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
Free()
|
||||
Offset() int
|
||||
Len() int
|
||||
}
|
||||
@@ -48,7 +47,6 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.values.Set(c.values.Concatenate(2, newValues))
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,19 +73,12 @@ func (c *KVCache) Trim(n int) int {
|
||||
}
|
||||
|
||||
func (c *KVCache) Clone() Cache {
|
||||
clone := &KVCache{
|
||||
return &KVCache{
|
||||
keys: c.keys.Clone(),
|
||||
values: c.values.Clone(),
|
||||
offset: c.offset,
|
||||
step: c.step,
|
||||
}
|
||||
mlx.Pin(clone.keys, clone.values)
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *KVCache) Free() {
|
||||
mlx.Unpin(c.keys, c.values)
|
||||
c.keys, c.values = nil, nil
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
@@ -115,8 +106,7 @@ func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = keys.Clone(), values.Clone()
|
||||
mlx.Pin(c.keys, c.values)
|
||||
c.keys, c.values = keys, values
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
@@ -155,7 +145,6 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
c.values.Set(c.values.Concatenate(2, newValues))
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
c.idx = prev
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -7,29 +7,48 @@ import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned bool
|
||||
type tensorDesc struct {
|
||||
name string
|
||||
inputs []*Array
|
||||
numRefs int
|
||||
}
|
||||
|
||||
var arrays []*Array
|
||||
func (d tensorDesc) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", d.name),
|
||||
slog.Int("inputs", len(d.inputs)),
|
||||
slog.Int("num_refs", d.numRefs),
|
||||
)
|
||||
}
|
||||
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
desc tensorDesc
|
||||
}
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
arrays = append(arrays, t)
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{
|
||||
desc: tensorDesc{
|
||||
name: name,
|
||||
inputs: inputs,
|
||||
},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
input.desc.numRefs++
|
||||
}
|
||||
logutil.Trace("New", "t", t)
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -114,51 +133,18 @@ func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
}
|
||||
|
||||
func (t *Array) Set(other *Array) {
|
||||
Free(t.desc.inputs...)
|
||||
other.desc.numRefs++
|
||||
t.desc.inputs = []*Array{other}
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.name)
|
||||
tt := New(t.desc.name, t.desc.inputs...)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// lifecycle utilities
|
||||
|
||||
// Pin marks arrays as in-use so they are retained during Sweep.
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
|
||||
// free them when there are no other references, including dependencies in the graph.
|
||||
func Sweep() {
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
arrays = arrays[:n]
|
||||
}
|
||||
|
||||
// misc. utilities
|
||||
|
||||
func (t *Array) Valid() bool {
|
||||
@@ -173,10 +159,7 @@ func (t *Array) String() string {
|
||||
}
|
||||
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Bool("pinned", t.pinned),
|
||||
}
|
||||
attrs := []slog.Attr{slog.Any("", t.desc)}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
slog.Any("dtype", t.DType()),
|
||||
@@ -255,15 +238,37 @@ func (t Array) Save(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogArrays logs all live arrays, sorted by size
|
||||
func LogArrays() {
|
||||
sort.Slice(arrays, func(i, j int) bool {
|
||||
return arrays[i].NumBytes() > arrays[j].NumBytes()
|
||||
})
|
||||
func Free(s ...*Array) (n int) {
|
||||
now := time.Now()
|
||||
defer func() {
|
||||
if n > 0 {
|
||||
logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
|
||||
}
|
||||
}()
|
||||
|
||||
for _, t := range arrays {
|
||||
nb := t.NumBytes()
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
|
||||
free := make([]*Array, 0, 8192)
|
||||
fn := func(t *Array) {
|
||||
if t.Valid() {
|
||||
t.desc.numRefs--
|
||||
if t.desc.numRefs <= 0 {
|
||||
free = append(free, t.desc.inputs...)
|
||||
logutil.Trace("Free", "t", t)
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||
|
||||
for _, t := range s {
|
||||
fn(t)
|
||||
}
|
||||
|
||||
for len(free) > 0 {
|
||||
tail := free[len(free)-1]
|
||||
free = free[:len(free)-1]
|
||||
fn(tail)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
@@ -55,30 +55,6 @@ func tryLoadFromDir(dir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tryLoadByName attempts to load the library using just its name,
|
||||
// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
|
||||
// Returns true if the library was successfully loaded.
|
||||
func tryLoadByName() bool {
|
||||
libraryName := "libmlxc.dylib"
|
||||
if runtime.GOOS == "linux" {
|
||||
libraryName = "libmlxc.so"
|
||||
}
|
||||
|
||||
cPath := C.CString(libraryName)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
return false
|
||||
}
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
@@ -97,11 +73,6 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// Try loading via rpath/standard library search
|
||||
if tryLoadByName() {
|
||||
return
|
||||
}
|
||||
|
||||
// Build search paths: executable directory, then build directories
|
||||
var searchDirs []string
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
|
||||
@@ -20,7 +20,7 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type LayerNorm struct {
|
||||
}
|
||||
|
||||
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
out := New("FAST_LAYERNORM", x)
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -41,7 +41,7 @@ type RMSNorm struct {
|
||||
}
|
||||
|
||||
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
out := New("FAST_RMSNORM", x)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -55,7 +55,7 @@ type RoPE struct {
|
||||
|
||||
func (r RoPE) Forward(t *Array, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE")
|
||||
out := New("FAST_ROPE", t, freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
t.ctx,
|
||||
|
||||
@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
|
||||
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
|
||||
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
|
||||
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
|
||||
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
|
||||
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
|
||||
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
|
||||
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
|
||||
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
|
||||
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
|
||||
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
|
||||
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
|
||||
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
|
||||
const mlx_vector_array input) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
|
||||
const mlx_vector_array input_2) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
|
||||
size_t input_2_num) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_)(void) = NULL;
|
||||
int (*mlx_enable_compile_)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_)(mlx_device dev) = NULL;
|
||||
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
|
||||
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_)(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_)(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_)(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_)(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_rope_dynamic_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_scaled_dot_product_attention_)(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_load_reader_)(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_linalg_cholesky_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
|
||||
int (*mlx_metal_is_available_)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_)(void) = NULL;
|
||||
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
int axis,
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
|
||||
const mlx_array values,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array w,
|
||||
const mlx_array w_scales /* may be null */,
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array w,
|
||||
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_axis_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_max_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_segmented_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
|
||||
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
|
||||
const char * (*mlx_string_data_)(mlx_string str) = NULL;
|
||||
int (*mlx_string_free_)(mlx_string str) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
|
||||
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
|
||||
int (*mlx_custom_function_)(
|
||||
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
|
||||
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
||||
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
|
||||
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_new_double);
|
||||
CHECK_LOAD(handle, mlx_array_new_complex);
|
||||
CHECK_LOAD(handle, mlx_array_new_data);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
|
||||
CHECK_LOAD(handle, mlx_array_set);
|
||||
CHECK_LOAD(handle, mlx_array_set_bool);
|
||||
CHECK_LOAD(handle, mlx_array_set_int);
|
||||
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_disable_compile);
|
||||
CHECK_LOAD(handle, mlx_enable_compile);
|
||||
CHECK_LOAD(handle, mlx_set_compile_mode);
|
||||
CHECK_LOAD(handle, mlx_cuda_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_new);
|
||||
CHECK_LOAD(handle, mlx_device_new_type);
|
||||
CHECK_LOAD(handle, mlx_device_free);
|
||||
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_device_get_type);
|
||||
CHECK_LOAD(handle, mlx_get_default_device);
|
||||
CHECK_LOAD(handle, mlx_set_default_device);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_device_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_count);
|
||||
CHECK_LOAD(handle, mlx_device_info_new);
|
||||
CHECK_LOAD(handle, mlx_device_info_get);
|
||||
CHECK_LOAD(handle, mlx_device_info_free);
|
||||
CHECK_LOAD(handle, mlx_device_info_has_key);
|
||||
CHECK_LOAD(handle, mlx_device_info_is_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_size);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_keys);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_gather);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_max);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_min);
|
||||
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_distributed_recv_like);
|
||||
CHECK_LOAD(handle, mlx_distributed_send);
|
||||
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_set_error_handler);
|
||||
CHECK_LOAD(handle, _mlx_error);
|
||||
CHECK_LOAD(handle, mlx_export_function);
|
||||
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
|
||||
CHECK_LOAD(handle, mlx_fast_rms_norm);
|
||||
CHECK_LOAD(handle, mlx_fast_rope);
|
||||
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
|
||||
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
|
||||
CHECK_LOAD(handle, mlx_fft_fft);
|
||||
CHECK_LOAD(handle, mlx_fft_fft2);
|
||||
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fft_rfft);
|
||||
CHECK_LOAD(handle, mlx_fft_rfft2);
|
||||
CHECK_LOAD(handle, mlx_fft_rfftn);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
|
||||
CHECK_LOAD(handle, mlx_linalg_cross);
|
||||
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_set_cache_limit);
|
||||
CHECK_LOAD(handle, mlx_set_memory_limit);
|
||||
CHECK_LOAD(handle, mlx_set_wired_limit);
|
||||
CHECK_LOAD(handle, mlx_metal_device_info);
|
||||
CHECK_LOAD(handle, mlx_metal_is_available);
|
||||
CHECK_LOAD(handle, mlx_metal_start_capture);
|
||||
CHECK_LOAD(handle, mlx_metal_stop_capture);
|
||||
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_full);
|
||||
CHECK_LOAD(handle, mlx_full_like);
|
||||
CHECK_LOAD(handle, mlx_gather);
|
||||
CHECK_LOAD(handle, mlx_gather_single);
|
||||
CHECK_LOAD(handle, mlx_gather_mm);
|
||||
CHECK_LOAD(handle, mlx_gather_qmm);
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_prod_axis);
|
||||
CHECK_LOAD(handle, mlx_prod);
|
||||
CHECK_LOAD(handle, mlx_put_along_axis);
|
||||
CHECK_LOAD(handle, mlx_qqmm);
|
||||
CHECK_LOAD(handle, mlx_quantize);
|
||||
CHECK_LOAD(handle, mlx_quantized_matmul);
|
||||
CHECK_LOAD(handle, mlx_radians);
|
||||
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_round);
|
||||
CHECK_LOAD(handle, mlx_rsqrt);
|
||||
CHECK_LOAD(handle, mlx_scatter);
|
||||
CHECK_LOAD(handle, mlx_scatter_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_axis);
|
||||
CHECK_LOAD(handle, mlx_scatter_max);
|
||||
CHECK_LOAD(handle, mlx_scatter_max_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_min);
|
||||
CHECK_LOAD(handle, mlx_scatter_min_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod_single);
|
||||
CHECK_LOAD(handle, mlx_segmented_mm);
|
||||
CHECK_LOAD(handle, mlx_sigmoid);
|
||||
CHECK_LOAD(handle, mlx_sign);
|
||||
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_string_set);
|
||||
CHECK_LOAD(handle, mlx_string_data);
|
||||
CHECK_LOAD(handle, mlx_string_free);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_async_eval);
|
||||
CHECK_LOAD(handle, mlx_checkpoint);
|
||||
CHECK_LOAD(handle, mlx_custom_function);
|
||||
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_jvp);
|
||||
CHECK_LOAD(handle, mlx_value_and_grad);
|
||||
CHECK_LOAD(handle, mlx_vjp);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_vector_array_new);
|
||||
CHECK_LOAD(handle, mlx_vector_array_set);
|
||||
CHECK_LOAD(handle, mlx_vector_array_free);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,10 @@
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
|
||||
@@ -37,9 +37,7 @@ func Load(path string) iter.Seq2[string, *Array] {
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
if !yield(name, arr) {
|
||||
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,43 +10,43 @@ import (
|
||||
)
|
||||
|
||||
func (t *Array) Abs() *Array {
|
||||
out := New("ABS")
|
||||
out := New("ABS", t)
|
||||
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Add(other *Array) *Array {
|
||||
out := New("ADD")
|
||||
out := New("ADD", t, other)
|
||||
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
|
||||
out := New("ADDMM")
|
||||
out := New("ADDMM", t, a, b)
|
||||
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Argmax(axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX")
|
||||
out := New("ARGMAX", t)
|
||||
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
|
||||
out := New("ARGPARTITION")
|
||||
out := New("ARGPARTITION", t)
|
||||
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgsortAxis(axis int) *Array {
|
||||
out := New("ARGSORT_AXIS")
|
||||
out := New("ARGSORT_AXIS", t)
|
||||
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsType(dtype DType) *Array {
|
||||
out := New("AS_TYPE")
|
||||
out := New("AS_TYPE", t)
|
||||
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
cStrides[i] = C.int64_t(s)
|
||||
}
|
||||
|
||||
out := New("AS_STRIDED")
|
||||
out := New("AS_STRIDED", t)
|
||||
C.mlx_as_strided(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(cShape), C.size_t(len(shape)),
|
||||
@@ -82,31 +82,31 @@ func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
C.mlx_vector_array_append_value(vector, other.ctx)
|
||||
}
|
||||
|
||||
out := New("CONCATENATE")
|
||||
out := New("CONCATENATE", s...)
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Divide(other *Array) *Array {
|
||||
out := New("DIVIDE")
|
||||
out := New("DIVIDE", t, other)
|
||||
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS")
|
||||
out := New("EXPAND_DIMS", t)
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Flatten(startAxis, endAxis int) *Array {
|
||||
out := New("FLATTEN")
|
||||
out := New("FLATTEN", t)
|
||||
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) FloorDivide(other *Array) *Array {
|
||||
out := New("FLOOR_DIVIDE")
|
||||
out := New("FLOOR_DIVIDE", t, other)
|
||||
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -118,43 +118,43 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_MM")
|
||||
out := New("GATHER_MM", t, other, lhs, rhs)
|
||||
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP")
|
||||
out := New("LOGSUMEXP", t)
|
||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
out := New("MATMUL", t, other)
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Multiply(other *Array) *Array {
|
||||
out := New("MULTIPLY")
|
||||
out := New("MULTIPLY", t, other)
|
||||
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Negative() *Array {
|
||||
out := New("NEGATIVE")
|
||||
out := New("NEGATIVE", t)
|
||||
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Power(exponent *Array) *Array {
|
||||
out := New("POWER")
|
||||
out := New("POWER", t, exponent)
|
||||
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("PUT_ALONG_AXIS")
|
||||
out := New("PUT_ALONG_AXIS", t, indices, values)
|
||||
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -165,25 +165,25 @@ func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
}
|
||||
|
||||
out := New("RESHAPE")
|
||||
out := New("RESHAPE", t)
|
||||
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sigmoid() *Array {
|
||||
out := New("SIGMOID")
|
||||
out := New("SIGMOID", t)
|
||||
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sqrt() *Array {
|
||||
out := New("SQRT")
|
||||
out := New("SQRT", t)
|
||||
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Squeeze(axis int) *Array {
|
||||
out := New("SQUEEZE")
|
||||
out := New("SQUEEZE", t)
|
||||
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -198,37 +198,37 @@ func (t *Array) StackAxis(axis int, others ...*Array) *Array {
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK_AXIS")
|
||||
out := New("STACK_AXIS", append(others, t)...)
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Subtract(other *Array) *Array {
|
||||
out := New("SUBTRACT")
|
||||
out := New("SUBTRACT", t, other)
|
||||
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
|
||||
out := New("SUM_AXIS")
|
||||
out := New("SUM_AXIS", t)
|
||||
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_AXIS")
|
||||
out := New("TAKE_AXIS", t, indices)
|
||||
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_ALONG_AXIS")
|
||||
out := New("TAKE_ALONG_AXIS", t, indices)
|
||||
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Tanh() *Array {
|
||||
out := New("TANH")
|
||||
out := New("TANH", t)
|
||||
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -239,7 +239,7 @@ func (t *Array) Transpose(axes ...int) *Array {
|
||||
cAxes[i] = C.int(axis)
|
||||
}
|
||||
|
||||
out := New("TRANSPOSE")
|
||||
out := New("TRANSPOSE", t)
|
||||
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -41,12 +41,14 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
optDtype := C.mlx_optional_dtype{has_value: false}
|
||||
|
||||
inputs := []*Array{w, scales}
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
inputs = append(inputs, biases)
|
||||
}
|
||||
|
||||
out := New("DEQUANTIZE")
|
||||
out := New("DEQUANTIZE", inputs...)
|
||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -57,12 +59,14 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
inputs := []*Array{x, w, scales}
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
inputs = append(inputs, biases)
|
||||
}
|
||||
|
||||
out := New("QUANTIZED_MATMUL")
|
||||
out := New("QUANTIZED_MATMUL", inputs...)
|
||||
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -73,18 +77,22 @@ func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, trans
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
inputs := []*Array{x, w, scales}
|
||||
var b, lhs, rhs C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
inputs = append(inputs, biases)
|
||||
}
|
||||
if lhsIndices != nil {
|
||||
lhs = lhsIndices.ctx
|
||||
inputs = append(inputs, lhsIndices)
|
||||
}
|
||||
if rhsIndices != nil {
|
||||
rhs = rhsIndices.ctx
|
||||
inputs = append(inputs, rhsIndices)
|
||||
}
|
||||
|
||||
out := New("GATHER_QMM")
|
||||
out := New("GATHER_QMM", inputs...)
|
||||
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -96,7 +104,7 @@ func Tile(a *Array, reps []int32) *Array {
|
||||
for i, r := range reps {
|
||||
cReps[i] = C.int(r)
|
||||
}
|
||||
out := New("TILE")
|
||||
out := New("TILE", a)
|
||||
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -108,7 +116,7 @@ func Tri(n, m int32, k int) *Array {
|
||||
}
|
||||
|
||||
func Where(condition, a, b *Array) *Array {
|
||||
out := New("WHERE")
|
||||
out := New("WHERE", condition, a, b)
|
||||
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -123,7 +131,7 @@ func Stack(arrays []*Array, axis int) *Array {
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK")
|
||||
out := New("STACK", arrays...)
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -145,13 +153,13 @@ func Take(a *Array, indices *Array, axis int) *Array {
|
||||
}
|
||||
|
||||
func RSqrt(a *Array) *Array {
|
||||
out := New("RSQRT")
|
||||
out := New("RSQRT", a)
|
||||
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Mean(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MEAN_AXIS")
|
||||
out := New("MEAN_AXIS", a)
|
||||
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -227,7 +235,7 @@ func SliceStartStop(a *Array, start, stop []int32) *Array {
|
||||
cStop[i] = C.int(stop[i])
|
||||
cStrides[i] = 1
|
||||
}
|
||||
out := New("SLICE")
|
||||
out := New("SLICE", a)
|
||||
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -249,7 +257,7 @@ func SiLU(a *Array) *Array {
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE")
|
||||
out := New("FAST_ROPE", x, freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
@@ -281,13 +289,13 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
out := New("FAST_SDPA", q, k, v, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
out := New("FAST_RMSNORM", x)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
@@ -314,7 +322,7 @@ func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
||||
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("ADD_SCALAR")
|
||||
out := New("ADD_SCALAR", a)
|
||||
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
@@ -322,7 +330,7 @@ func AddScalar(a *Array, s float32) *Array {
|
||||
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("MUL_SCALAR")
|
||||
out := New("MUL_SCALAR", a)
|
||||
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
@@ -330,7 +338,7 @@ func MulScalar(a *Array, s float32) *Array {
|
||||
|
||||
func DivScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("DIV_SCALAR")
|
||||
out := New("DIV_SCALAR", a)
|
||||
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
|
||||
@@ -7,7 +7,7 @@ import "C"
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
key := New("")
|
||||
out := New("")
|
||||
out := New("", t, key)
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
|
||||
func (t *Array) Slice(slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE")
|
||||
out := New("SLICE", t)
|
||||
C.mlx_slice(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
@@ -74,7 +74,7 @@ func (t *Array) Slice(slices ...slice) *Array {
|
||||
|
||||
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE_UPDATE")
|
||||
out := New("SLICE_UPDATE", t, other)
|
||||
C.mlx_slice_update(
|
||||
&out.ctx, t.ctx, other.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
@@ -78,21 +78,8 @@ func New(root *model.Root) (Model, error) {
|
||||
return fn(root)
|
||||
}
|
||||
|
||||
// Weights returns a function that loads model weights, then pins all
|
||||
// arrays reachable from the model struct and sweeps everything else.
|
||||
// Weights returns the model's LoadWeights method, which encapsulates all
|
||||
// weight assignment and post-processing (MLA absorption, expert stacking).
|
||||
func Weights(m Model) func(map[string]*mlx.Array) error {
|
||||
return func(tensors map[string]*mlx.Array) error {
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
for _, arr := range collected {
|
||||
mlx.Pin(arr)
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
return m.LoadWeights
|
||||
}
|
||||
|
||||
@@ -4,12 +4,11 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
@@ -47,8 +46,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
for total-processed > 1 {
|
||||
n := min(2<<10, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
mlx.Sweep()
|
||||
temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
defer mlx.Free(temp)
|
||||
mlx.Eval(func() []*mlx.Array {
|
||||
s := make([]*mlx.Array, 2*len(caches))
|
||||
for i, c := range caches {
|
||||
@@ -67,16 +66,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||
sample := request.Sample(logprobs)
|
||||
|
||||
mlx.Pin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
|
||||
return sample, logprobs
|
||||
return request.Sample(logprobs), logprobs
|
||||
}
|
||||
|
||||
sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
@@ -85,6 +79,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
outputs := make([]int32, 0, request.Options.MaxTokens)
|
||||
for i := range request.Options.MaxTokens {
|
||||
nextSample, nextLogprobs := step(sample)
|
||||
mlx.AsyncEval(nextSample, nextLogprobs)
|
||||
|
||||
if i == 0 {
|
||||
slog.Info("Prompt processing progress", "processed", total, "total", total)
|
||||
@@ -97,7 +92,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
outputs = append(outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
final.Token = int(output)
|
||||
final.DoneReason = 0
|
||||
final.CompletionTokens = i
|
||||
@@ -109,7 +103,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
Token: int(output),
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Free(sample, logprobs)
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
@@ -117,19 +111,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Free(sample, logprobs)
|
||||
final.CompletionTokensDuration = time.Since(now)
|
||||
request.Responses <- final
|
||||
r.InsertCache(append(inputs, outputs...), caches)
|
||||
mlx.Sweep()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
mlx.LogArrays()
|
||||
if r.cache != nil {
|
||||
r.cache.LogCache()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -141,5 +126,13 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
return flushValidUTF8Prefix(b)
|
||||
if text := b.String(); utf8.ValidString(text) {
|
||||
b.Reset()
|
||||
return text
|
||||
} else if b.Len() >= utf8.UTFMax {
|
||||
b.Reset()
|
||||
return text
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
@@ -58,10 +58,10 @@ type Response struct {
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
cache *CacheEntry
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
CacheEntries map[int32]*CacheEntry
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
|
||||
@@ -40,7 +40,8 @@ func Execute(args []string) error {
|
||||
flagSet.Parse(args)
|
||||
|
||||
runner := Runner{
|
||||
Requests: make(chan Request),
|
||||
Requests: make(chan Request),
|
||||
CacheEntries: make(map[int32]*CacheEntry),
|
||||
}
|
||||
|
||||
if err := runner.Load(modelName); err != nil {
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
|
||||
// currently buffered, leaving any incomplete trailing bytes in place.
|
||||
func flushValidUTF8Prefix(b *bytes.Buffer) string {
|
||||
data := b.Bytes()
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := validUTF8PrefixLen(data)
|
||||
if prefix == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := string(data[:prefix])
|
||||
b.Next(prefix)
|
||||
return text
|
||||
}
|
||||
|
||||
func validUTF8PrefixLen(data []byte) int {
|
||||
i := 0
|
||||
prefix := 0
|
||||
for i < len(data) {
|
||||
r, size := utf8.DecodeRune(data[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
if !utf8.FullRune(data[i:]) {
|
||||
break
|
||||
}
|
||||
|
||||
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
|
||||
i++
|
||||
prefix = i
|
||||
continue
|
||||
}
|
||||
|
||||
i += size
|
||||
prefix = i
|
||||
}
|
||||
|
||||
return prefix
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
||||
b.Write([]byte{0xE3, 0x81})
|
||||
if got := flushValidUTF8Prefix(&b); got != "" {
|
||||
t.Fatalf("first flush = %q, want empty", got)
|
||||
}
|
||||
|
||||
b.Write([]byte{0x93, 0xE3})
|
||||
if got := flushValidUTF8Prefix(&b); got != "こ" {
|
||||
t.Fatalf("second flush = %q, want %q", got, "こ")
|
||||
}
|
||||
|
||||
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
|
||||
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
|
||||
}
|
||||
|
||||
b.Write([]byte{0x82, 0x93})
|
||||
if got := flushValidUTF8Prefix(&b); got != "ん" {
|
||||
t.Fatalf("third flush = %q, want %q", got, "ん")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after third flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("hello 世界")
|
||||
|
||||
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
|
||||
t.Fatalf("flush = %q, want %q", got, "hello 世界")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -401,6 +401,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
if m.NormScaled == nil {
|
||||
return fmt.Errorf("missing precomputed final norm weight")
|
||||
}
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -702,6 +702,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
}
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -235,6 +235,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -252,6 +252,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
|
||||
//
|
||||
// Based on standard BPE algorithm (Sennrich et al. 2015) with:
|
||||
// - GPT-2 byte-level encoding (OpenAI tiktoken)
|
||||
// - HuggingFace tokenizer.json pretokenizer patterns
|
||||
// - SentencePiece ▁-style space handling
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "regexp"
|
||||
|
||||
// TokenizerType identifies the tokenization algorithm
|
||||
type TokenizerType int
|
||||
|
||||
const (
|
||||
TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
|
||||
TokenizerSentencePiece // SentencePiece with ▁ for spaces
|
||||
)
|
||||
|
||||
// Vocabulary holds the tokenizer vocabulary and merges
|
||||
type Vocabulary struct {
|
||||
Values []string
|
||||
Reverse map[string]int32
|
||||
Merges map[string]int
|
||||
|
||||
BOS int32
|
||||
EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
|
||||
PAD int32 // Padding token (often <|endoftext|> or <pad>)
|
||||
AddBOS bool
|
||||
AddEOS bool
|
||||
|
||||
// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
|
||||
byteTokens [256]int32
|
||||
}
|
||||
|
||||
// Tokenizer handles BPE and SentencePiece tokenization
|
||||
type Tokenizer struct {
|
||||
vocab *Vocabulary
|
||||
pretokenizer *regexp.Regexp
|
||||
specialTokens map[string]int32 // Special tokens for direct lookup
|
||||
sortedSpecialTokens []string // Special tokens sorted by length, longest first
|
||||
typ TokenizerType // Algorithm type
|
||||
}
|
||||
|
||||
// Precomputed GPT-2 byte-level encoding table
|
||||
// Maps byte values to their encoded rune equivalents
|
||||
var byteToRune [256]rune
|
||||
|
||||
func init() {
|
||||
for b := 0; b < 256; b++ {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
byteToRune[b] = r
|
||||
}
|
||||
}
|
||||
|
||||
// VocabSize returns the vocabulary size
|
||||
func (t *Tokenizer) VocabSize() int {
|
||||
return len(t.vocab.Values)
|
||||
}
|
||||
|
||||
// BOS returns the beginning of sequence token ID
|
||||
func (t *Tokenizer) BOS() int32 {
|
||||
return t.vocab.BOS
|
||||
}
|
||||
|
||||
// EOS returns the first end of sequence token ID (for backwards compatibility)
|
||||
func (t *Tokenizer) EOS() int32 {
|
||||
if len(t.vocab.EOS) > 0 {
|
||||
return t.vocab.EOS[0]
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// EOSTokens returns all end of sequence token IDs
|
||||
func (t *Tokenizer) EOSTokens() []int32 {
|
||||
return t.vocab.EOS
|
||||
}
|
||||
|
||||
// PAD returns the padding token ID, or -1 if not set
|
||||
func (t *Tokenizer) PAD() int32 {
|
||||
return t.vocab.PAD
|
||||
}
|
||||
|
||||
// IsEOS returns true if the token ID is an end of sequence token
|
||||
func (t *Tokenizer) IsEOS(id int32) bool {
|
||||
for _, eos := range t.vocab.EOS {
|
||||
if id == eos {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSpecialToken returns the token ID for a special token string
|
||||
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
|
||||
id, ok := t.specialTokens[name]
|
||||
return id, ok
|
||||
}
|
||||
@@ -1,251 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
benchmarkSinkIDs []int32
|
||||
benchmarkSinkStr string
|
||||
benchmarkSinkTok *Tokenizer
|
||||
)
|
||||
|
||||
const benchmarkWordPieceJSON = `{
|
||||
"model": {
|
||||
"type": "WordPiece",
|
||||
"vocab": {
|
||||
"[UNK]": 0,
|
||||
"hello": 1,
|
||||
"##world": 2,
|
||||
"##ly": 3,
|
||||
"##hello": 4
|
||||
}
|
||||
},
|
||||
"added_tokens": []
|
||||
}`
|
||||
|
||||
const benchmarkSentencePieceJSON = `{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"\u2581": 0,
|
||||
"h": 1,
|
||||
"e": 2,
|
||||
"l": 3,
|
||||
"o": 4,
|
||||
"w": 5,
|
||||
"r": 6,
|
||||
"d": 7,
|
||||
"<0x0A>": 8
|
||||
},
|
||||
"merges": []
|
||||
},
|
||||
"decoder": {
|
||||
"type": "Sequence",
|
||||
"decoders": [
|
||||
{
|
||||
"type": "Replace",
|
||||
"pattern": {
|
||||
"String": "\u2581"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"added_tokens": []
|
||||
}`
|
||||
|
||||
func benchmarkMiniLlamaPath(tb testing.TB) string {
|
||||
tb.Helper()
|
||||
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
tb.Fatal("failed to resolve benchmark file path")
|
||||
}
|
||||
|
||||
return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
|
||||
}
|
||||
|
||||
func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
data := benchmarkLoadMiniLlamaBytes(tb)
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load mini llama tokenizer: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
|
||||
tb.Helper()
|
||||
|
||||
data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to read mini llama tokenizer: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load tokenizer from bytes: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "short", text: "Hello, world!"},
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
|
||||
{name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
{name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
|
||||
{name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(input.text, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeBPE(b *testing.B) {
|
||||
tok := benchmarkLoadMiniLlama(b)
|
||||
|
||||
inputs := []struct {
|
||||
name string
|
||||
text string
|
||||
}{
|
||||
{name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
|
||||
{name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input.text, false)
|
||||
b.Run(input.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(input.text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
|
||||
data := benchmarkLoadMiniLlamaBytes(b)
|
||||
|
||||
config := &TokenizerConfig{
|
||||
TokenizerConfigJSON: []byte(`{
|
||||
"bos_token": {"content": "<|begin_of_text|>"},
|
||||
"eos_token": {"content": "<|end_of_text|>"},
|
||||
"add_bos_token": true
|
||||
}`),
|
||||
GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
|
||||
}
|
||||
|
||||
b.Run("without_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytes failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("with_config", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
tok, err := LoadFromBytesWithConfig(data, config)
|
||||
if err != nil {
|
||||
b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
|
||||
}
|
||||
benchmarkSinkTok = tok
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
|
||||
text := strings.Repeat("helloworldly", 16)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkIDs = tok.Encode(text, false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
|
||||
tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
|
||||
text := strings.Repeat("hello world\n", 64)
|
||||
ids := tok.Encode(text, false)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(text)))
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSinkStr = tok.Decode(ids)
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import "container/heap"
|
||||
|
||||
type bpeMergeNode struct {
|
||||
prev int
|
||||
next int
|
||||
token string
|
||||
}
|
||||
|
||||
type bpePair struct {
|
||||
left int
|
||||
right int
|
||||
rank int
|
||||
value string
|
||||
}
|
||||
|
||||
type bpePairHeap []*bpePair
|
||||
|
||||
func (h bpePairHeap) Len() int { return len(h) }
|
||||
|
||||
func (h bpePairHeap) Less(i, j int) bool {
|
||||
return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
|
||||
}
|
||||
|
||||
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
|
||||
func (h *bpePairHeap) Push(x any) {
|
||||
*h = append(*h, x.(*bpePair))
|
||||
}
|
||||
|
||||
func (h *bpePairHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
*h = old[:n-1]
|
||||
return item
|
||||
}
|
||||
|
||||
// encodeBPEMerge encodes using BPE merge algorithm.
|
||||
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
|
||||
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
|
||||
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
|
||||
runes := []rune(encoded)
|
||||
if len(runes) == 0 {
|
||||
return ids
|
||||
}
|
||||
|
||||
nodes := make([]bpeMergeNode, len(runes))
|
||||
for i := range runes {
|
||||
nodes[i] = bpeMergeNode{
|
||||
prev: i - 1,
|
||||
next: i + 1,
|
||||
token: string(runes[i]),
|
||||
}
|
||||
}
|
||||
|
||||
pairwise := func(left, right int) *bpePair {
|
||||
if left < 0 || right >= len(nodes) {
|
||||
return nil
|
||||
}
|
||||
if nodes[left].token == "" || nodes[right].token == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
leftToken, rightToken := nodes[left].token, nodes[right].token
|
||||
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
value := leftToken + rightToken
|
||||
if _, ok := t.vocab.Reverse[value]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &bpePair{
|
||||
left: left,
|
||||
right: right,
|
||||
rank: rank,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
pairs := bpePairHeap{}
|
||||
heap.Init(&pairs)
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
if pair := pairwise(i, i+1); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for pairs.Len() > 0 {
|
||||
pair := heap.Pop(&pairs).(*bpePair)
|
||||
left, right := nodes[pair.left], nodes[pair.right]
|
||||
if left.token == "" || right.token == "" {
|
||||
continue
|
||||
}
|
||||
if left.next != pair.right || right.prev != pair.left {
|
||||
continue
|
||||
}
|
||||
if left.token+right.token != pair.value {
|
||||
continue
|
||||
}
|
||||
|
||||
nodes[pair.left].token = pair.value
|
||||
nodes[pair.right].token = ""
|
||||
nodes[pair.left].next = right.next
|
||||
if right.next < len(nodes) {
|
||||
nodes[right.next].prev = pair.left
|
||||
}
|
||||
|
||||
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
|
||||
heap.Push(&pairs, pair)
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.token == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if id, ok := t.vocab.Reverse[node.token]; ok {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
ids = t.appendByteFallback(ids, node.token)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
|
||||
if t.typ == TokenizerBPE {
|
||||
for _, r := range token {
|
||||
if b, ok := decodeByteLevelRune(r); ok {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
|
||||
for _, b := range []byte(token) {
|
||||
if id := t.vocab.byteTokens[b]; id >= 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func decodeByteLevelRune(r rune) (byte, bool) {
|
||||
switch {
|
||||
case r >= 0x00 && r <= 0xFF:
|
||||
return byte(r), true
|
||||
case r == 0x0100:
|
||||
return 0x00, true
|
||||
case r == 0x0143:
|
||||
return 0x00ad, true
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
return byte(r - 0x0100), true
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
return byte(r - 0x00a2), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func equalIDs(a, b []int32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestEncodeRoundtripMiniLlama(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
inputs := []string{
|
||||
"",
|
||||
"hello",
|
||||
"hello world",
|
||||
" hello world ",
|
||||
"don't we'll they're",
|
||||
"1234567890",
|
||||
"こんにちは世界",
|
||||
"Hello 世界",
|
||||
"func main() {}",
|
||||
"<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
|
||||
strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {"a": 0, "b": 1},
|
||||
"merges": []
|
||||
},
|
||||
"added_tokens": [
|
||||
{"id": 2, "content": "<tag>", "special": true},
|
||||
{"id": 3, "content": "<tag>x", "special": true}
|
||||
]
|
||||
}`)
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load tokenizer: %v", err)
|
||||
}
|
||||
|
||||
input := "a<tag>xb"
|
||||
want := []string{"a", "<tag>x", "b"}
|
||||
|
||||
// Simulate construction outside loader path where cache is not set.
|
||||
tok.sortedSpecialTokens = nil
|
||||
|
||||
got := tok.splitBySpecialTokens(input)
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("split length mismatch: got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
|
||||
tok := benchmarkLoadMiniLlama(t)
|
||||
|
||||
input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
|
||||
|
||||
prev := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(prev)
|
||||
|
||||
runtime.GOMAXPROCS(1)
|
||||
seq := tok.Encode(input, false)
|
||||
|
||||
if prev < 2 {
|
||||
runtime.GOMAXPROCS(2)
|
||||
} else {
|
||||
runtime.GOMAXPROCS(prev)
|
||||
}
|
||||
par := tok.Encode(input, false)
|
||||
|
||||
if !equalIDs(seq, par) {
|
||||
t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Decode converts token IDs back to text
|
||||
func (t *Tokenizer) Decode(ids []int32) string {
|
||||
var sb strings.Builder
|
||||
|
||||
for _, id := range ids {
|
||||
if int(id) >= len(t.vocab.Values) {
|
||||
continue
|
||||
}
|
||||
|
||||
token := t.vocab.Values[id]
|
||||
|
||||
switch t.typ {
|
||||
case TokenizerSentencePiece:
|
||||
// SentencePiece style: replace ▁ with space, decode byte tokens
|
||||
token = strings.ReplaceAll(token, "▁", " ")
|
||||
// Handle byte fallback tokens like <0x0D>
|
||||
if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
|
||||
if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
|
||||
sb.WriteByte(byte(v))
|
||||
continue
|
||||
}
|
||||
}
|
||||
sb.WriteString(token)
|
||||
default:
|
||||
// GPT-2 BPE style: decode byte-level encoding
|
||||
for _, r := range token {
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// Mirror GGML tokenizer behavior for NULL byte.
|
||||
// 0x00 is omitted during decode.
|
||||
continue
|
||||
case r == 0x0143:
|
||||
r = 0x00ad
|
||||
case r > 0x0100 && r <= 0x0120:
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
}
|
||||
|
||||
// Write as byte, not UTF-8 encoded rune
|
||||
sb.WriteByte(byte(r))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1,289 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
encodeParallelMinInputBytes = 4 * 1024
|
||||
encodeParallelMinChunksPerWorker = 8
|
||||
)
|
||||
|
||||
type tokenMatch struct {
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
type encodeChunk struct {
|
||||
text string
|
||||
isSpecial bool
|
||||
}
|
||||
|
||||
// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
|
||||
func isNonNewlineWhitespace(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if r == '\n' || r == '\r' {
|
||||
return false
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
|
||||
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
|
||||
if len(t.specialTokens) == 0 {
|
||||
return []string{s}
|
||||
}
|
||||
|
||||
tokens := t.sortedSpecialTokens
|
||||
if len(tokens) == 0 {
|
||||
// Fallback for tokenizers constructed outside the loaders.
|
||||
tokens = make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
}
|
||||
|
||||
var result []string
|
||||
remaining := s
|
||||
|
||||
for len(remaining) > 0 {
|
||||
found := false
|
||||
for _, tok := range tokens {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
result = append(result, tok)
|
||||
remaining = remaining[len(tok):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Find next special token position
|
||||
nextPos := len(remaining)
|
||||
for _, tok := range tokens {
|
||||
if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
|
||||
nextPos = idx
|
||||
}
|
||||
}
|
||||
if nextPos > 0 {
|
||||
result = append(result, remaining[:nextPos])
|
||||
}
|
||||
remaining = remaining[nextPos:]
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
|
||||
m := part[curr.start:curr.end]
|
||||
nextText := part[next.start:next.end]
|
||||
|
||||
if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
firstRune, _ := utf8.DecodeRuneInString(nextText)
|
||||
if !unicode.IsLetter(firstRune) {
|
||||
return
|
||||
}
|
||||
|
||||
lastSpaceStart := curr.end
|
||||
for j := curr.end; j > curr.start; {
|
||||
r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
|
||||
if unicode.IsSpace(r) {
|
||||
lastSpaceStart = j - size
|
||||
break
|
||||
}
|
||||
j -= size
|
||||
}
|
||||
if lastSpaceStart > curr.start {
|
||||
curr.end = lastSpaceStart
|
||||
next.start = lastSpaceStart
|
||||
} else {
|
||||
next.start = curr.start
|
||||
curr.end = curr.start
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
|
||||
if _, ok := t.specialTokens[part]; ok {
|
||||
fn(encodeChunk{text: part, isSpecial: true})
|
||||
return
|
||||
}
|
||||
|
||||
if t.pretokenizer == nil {
|
||||
fn(encodeChunk{text: part, isSpecial: false})
|
||||
return
|
||||
}
|
||||
|
||||
re := t.pretokenizer
|
||||
offset := 0
|
||||
loc := re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
for {
|
||||
loc = re.FindStringIndex(part[offset:])
|
||||
if loc == nil {
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
|
||||
offset += loc[1]
|
||||
|
||||
adjustWhitespaceBoundary(part, &curr, &next)
|
||||
|
||||
if curr.end > curr.start {
|
||||
fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
|
||||
}
|
||||
curr = next
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
|
||||
if c.isSpecial {
|
||||
if id, ok := t.specialTokens[c.text]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
return t.encodeChunkInto(c.text, ids)
|
||||
}
|
||||
|
||||
// Encode tokenizes text to token IDs.
|
||||
// Parallel encoding is used only for very large inputs with enough chunks per worker.
|
||||
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
|
||||
// First: split by special tokens
|
||||
parts := t.splitBySpecialTokens(s)
|
||||
|
||||
// Fast path: encode sequentially without materializing chunk slices.
|
||||
if len(s) < encodeParallelMinInputBytes {
|
||||
var ids []int32
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
})
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// For large inputs collect chunks to enable parallel processing.
|
||||
var allChunks []encodeChunk
|
||||
for _, part := range parts {
|
||||
t.forEachPartChunk(part, func(c encodeChunk) {
|
||||
allChunks = append(allChunks, c)
|
||||
})
|
||||
}
|
||||
|
||||
// Encode chunks. Use the parallel path only when the chunk count is
|
||||
// large enough to amortize goroutine/synchronization overhead.
|
||||
useParallel := true
|
||||
numWorkers := runtime.GOMAXPROCS(0)
|
||||
if numWorkers > len(allChunks) {
|
||||
numWorkers = len(allChunks)
|
||||
}
|
||||
if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
|
||||
useParallel = false
|
||||
}
|
||||
|
||||
var ids []int32
|
||||
if !useParallel {
|
||||
for _, c := range allChunks {
|
||||
ids = t.appendEncodedChunk(ids, c)
|
||||
}
|
||||
} else {
|
||||
chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
|
||||
results := make([][]int32, numWorkers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
start := i * chunksPer
|
||||
end := start + chunksPer
|
||||
if end > len(allChunks) {
|
||||
end = len(allChunks)
|
||||
}
|
||||
if start >= end {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(i int, chunks []encodeChunk) {
|
||||
defer wg.Done()
|
||||
var r []int32
|
||||
for _, c := range chunks {
|
||||
r = t.appendEncodedChunk(r, c)
|
||||
}
|
||||
results[i] = r
|
||||
}(i, allChunks[start:end])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, r := range results {
|
||||
ids = append(ids, r...)
|
||||
}
|
||||
}
|
||||
|
||||
if addBOS && t.vocab.BOS >= 0 {
|
||||
ids = append([]int32{t.vocab.BOS}, ids...)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
|
||||
// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
|
||||
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
|
||||
if s == "" {
|
||||
return ids
|
||||
}
|
||||
|
||||
// Apply encoding transformation
|
||||
// SentencePiece: replace space with ▁
|
||||
// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
|
||||
var encoded string
|
||||
if t.typ == TokenizerSentencePiece {
|
||||
encoded = strings.ReplaceAll(s, " ", "▁")
|
||||
} else {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(s) * 2)
|
||||
for i := 0; i < len(s); i++ {
|
||||
sb.WriteRune(byteToRune[s[i]])
|
||||
}
|
||||
encoded = sb.String()
|
||||
}
|
||||
|
||||
// Fast path: check if entire chunk is a single token
|
||||
if id, ok := t.vocab.Reverse[encoded]; ok {
|
||||
return append(ids, id)
|
||||
}
|
||||
|
||||
return t.encodeBPEMerge(encoded, ids)
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func llama32GGMLFixturePath(tb testing.TB, file string) string {
|
||||
tb.Helper()
|
||||
|
||||
_, filename, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
tb.Fatal("failed to resolve test file path")
|
||||
}
|
||||
|
||||
return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
|
||||
}
|
||||
|
||||
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to open encoder.json: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
vocab := make(map[string]int32)
|
||||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||||
tb.Fatalf("failed to decode encoder.json: %v", err)
|
||||
}
|
||||
|
||||
type addedToken struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
}
|
||||
var addedTokens []addedToken
|
||||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||||
if _, ok := vocab[token]; !ok {
|
||||
id := int32(len(vocab))
|
||||
vocab[token] = id
|
||||
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
|
||||
}
|
||||
}
|
||||
|
||||
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to open vocab.bpe: %v", err)
|
||||
}
|
||||
defer mf.Close()
|
||||
|
||||
var merges []string
|
||||
scanner := bufio.NewScanner(mf)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
merges = append(merges, line)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
tb.Fatalf("failed to read vocab.bpe: %v", err)
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
} `json:"model"`
|
||||
PreTokenizer struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
} `json:"pretokenizers"`
|
||||
} `json:"pre_tokenizer"`
|
||||
AddedTokens []addedToken `json:"added_tokens"`
|
||||
}{}
|
||||
|
||||
payload.Model.Type = "BPE"
|
||||
payload.Model.Vocab = vocab
|
||||
payload.Model.Merges = merges
|
||||
payload.PreTokenizer.Type = "Sequence"
|
||||
payload.PreTokenizer.Pretokenizers = []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
}{
|
||||
{
|
||||
Type: "Split",
|
||||
Pattern: struct {
|
||||
Regex string `json:"Regex"`
|
||||
}{
|
||||
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
},
|
||||
},
|
||||
}
|
||||
payload.AddedTokens = addedTokens
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
|
||||
}
|
||||
|
||||
tok, err := LoadFromBytes(data)
|
||||
if err != nil {
|
||||
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
|
||||
}
|
||||
return tok
|
||||
}
|
||||
|
||||
func TestGGMLLlamaKnownEncodings(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[string][]int32{
|
||||
"hello world": {15339, 1917},
|
||||
"hello <|end_of_text|>": {15339, 220, 128001},
|
||||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||||
}
|
||||
|
||||
for input, want := range cases {
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := map[int][]int32{
|
||||
1: {15},
|
||||
2: {410},
|
||||
3: {931},
|
||||
4: {931, 15},
|
||||
5: {931, 410},
|
||||
6: {931, 931},
|
||||
7: {931, 931, 15},
|
||||
8: {931, 931, 410},
|
||||
9: {931, 931, 931},
|
||||
10: {931, 931, 931, 15},
|
||||
11: {931, 931, 931, 410},
|
||||
12: {931, 931, 931, 931},
|
||||
13: {931, 931, 931, 931, 15},
|
||||
14: {931, 931, 931, 931, 410},
|
||||
15: {931, 931, 931, 931, 931},
|
||||
16: {931, 931, 931, 931, 931, 15},
|
||||
17: {931, 931, 931, 931, 931, 410},
|
||||
}
|
||||
|
||||
for n, want := range cases {
|
||||
input := strings.Repeat("0", n)
|
||||
got := tok.Encode(input, false)
|
||||
if !equalIDs(got, want) {
|
||||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
|
||||
tok := loadLlama32FromGGMLFixture(t)
|
||||
|
||||
cases := []string{
|
||||
"hello",
|
||||
"hello ",
|
||||
"hello ",
|
||||
" hello",
|
||||
" hello ",
|
||||
" hello ",
|
||||
"hello world",
|
||||
"请考试我的软件!12345",
|
||||
}
|
||||
|
||||
for _, input := range cases {
|
||||
ids := tok.Encode(input, false)
|
||||
got := tok.Decode(ids)
|
||||
if got != input {
|
||||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
|
||||
ids := tok.Encode(string(rune(0x00)), false)
|
||||
got := tok.Decode(ids)
|
||||
if got != "" {
|
||||
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
|
||||
}
|
||||
}
|
||||
@@ -1,458 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
|
||||
type TokenizerConfig struct {
|
||||
TokenizerConfigJSON []byte // tokenizer_config.json content
|
||||
GenerationConfigJSON []byte // generation_config.json content
|
||||
SpecialTokensMapJSON []byte // special_tokens_map.json content
|
||||
ConfigJSON []byte // config.json content
|
||||
}
|
||||
|
||||
// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
|
||||
// This is useful when loading from blob storage where the file content is already in memory.
|
||||
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
|
||||
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
|
||||
func LoadFromBytes(data []byte) (*Tokenizer, error) {
|
||||
return loadFromTokenizerJSON(data)
|
||||
}
|
||||
|
||||
// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
|
||||
// This is useful when loading from blob storage where companion config files are also blobs.
|
||||
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
|
||||
t, err := loadFromTokenizerJSON(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Apply special token configs from provided data
|
||||
loadSpecialTokenConfigFromBytes(t, config)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// loadFromTokenizerJSON parses tokenizer.json content from bytes.
|
||||
func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
|
||||
|
||||
var raw struct {
|
||||
Model struct {
|
||||
Type string `json:"type"` // "BPE"
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
|
||||
} `json:"model"`
|
||||
PreTokenizer json.RawMessage `json:"pre_tokenizer"`
|
||||
Decoder json.RawMessage `json:"decoder"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Covers SentencePiece and BPE models
|
||||
if raw.Model.Type != "BPE" {
|
||||
return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
|
||||
}
|
||||
|
||||
// Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
|
||||
var mergesStrings []string
|
||||
if raw.Model.Merges != nil {
|
||||
var mergesArrays [][]string
|
||||
if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
|
||||
// Try array of arrays format
|
||||
if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse merges: %w", err)
|
||||
}
|
||||
// Convert [][]string to []string
|
||||
mergesStrings = make([]string, len(mergesArrays))
|
||||
for i, pair := range mergesArrays {
|
||||
if len(pair) != 2 {
|
||||
return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
|
||||
}
|
||||
mergesStrings[i] = pair[0] + " " + pair[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build tokenizer
|
||||
t := &Tokenizer{
|
||||
vocab: &Vocabulary{
|
||||
Values: make([]string, len(raw.Model.Vocab)),
|
||||
Reverse: raw.Model.Vocab,
|
||||
Merges: make(map[string]int, len(mergesStrings)),
|
||||
BOS: -1,
|
||||
PAD: -1,
|
||||
},
|
||||
specialTokens: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Build values array
|
||||
for token, id := range raw.Model.Vocab {
|
||||
if int(id) >= len(t.vocab.Values) {
|
||||
newValues := make([]string, id+1)
|
||||
copy(newValues, t.vocab.Values)
|
||||
t.vocab.Values = newValues
|
||||
}
|
||||
t.vocab.Values[id] = token
|
||||
}
|
||||
|
||||
// Build merges map
|
||||
for i, merge := range mergesStrings {
|
||||
t.vocab.Merges[merge] = i
|
||||
}
|
||||
|
||||
// Add all added_tokens to vocabulary and special tokens map.
|
||||
// HuggingFace treats ALL added_tokens as special for tokenization purposes -
|
||||
// they bypass BPE and get their own token ID. The "special" flag just indicates
|
||||
// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
|
||||
// to treat all added_tokens as special to match HuggingFace behavior.
|
||||
for _, tok := range raw.AddedTokens {
|
||||
if int(tok.ID) >= len(t.vocab.Values) {
|
||||
newValues := make([]string, tok.ID+1)
|
||||
copy(newValues, t.vocab.Values)
|
||||
t.vocab.Values = newValues
|
||||
}
|
||||
t.vocab.Values[tok.ID] = tok.Content
|
||||
t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
|
||||
}
|
||||
|
||||
// Precompute byte token IDs for <0xNN> fallback
|
||||
initByteTokens(t)
|
||||
|
||||
// Determine tokenizer type
|
||||
switch {
|
||||
case detectSentencePiece(raw.Decoder):
|
||||
t.typ = TokenizerSentencePiece
|
||||
default:
|
||||
t.typ = TokenizerBPE
|
||||
}
|
||||
|
||||
// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
|
||||
if t.typ == TokenizerBPE {
|
||||
pattern := extractPretokenizer(raw.PreTokenizer)
|
||||
if pattern == "" {
|
||||
pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
|
||||
}
|
||||
re, err := regexp.Compile(rewritePatternForRE2(pattern))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
|
||||
}
|
||||
t.pretokenizer = re
|
||||
}
|
||||
|
||||
cacheSortedSpecialTokens(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func cacheSortedSpecialTokens(t *Tokenizer) {
|
||||
if len(t.specialTokens) == 0 {
|
||||
t.sortedSpecialTokens = nil
|
||||
return
|
||||
}
|
||||
|
||||
tokens := make([]string, 0, len(t.specialTokens))
|
||||
for tok := range t.specialTokens {
|
||||
tokens = append(tokens, tok)
|
||||
}
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return len(tokens[i]) > len(tokens[j])
|
||||
})
|
||||
t.sortedSpecialTokens = tokens
|
||||
}
|
||||
|
||||
type specialTokenConfigData struct {
|
||||
tokenizerConfigJSON []byte
|
||||
generationConfigJSON []byte
|
||||
specialTokensMapJSON []byte
|
||||
configJSON []byte
|
||||
}
|
||||
|
||||
func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
|
||||
parseTokenIDs := func(v interface{}) []int32 {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return []int32{int32(val)}
|
||||
case []interface{}:
|
||||
ids := make([]int32, 0, len(val))
|
||||
for _, id := range val {
|
||||
if f, ok := id.(float64); ok {
|
||||
ids = append(ids, int32(f))
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Priority 1: generation_config.json
|
||||
if len(config.generationConfigJSON) > 0 {
|
||||
var genConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
|
||||
if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: config.json
|
||||
if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
|
||||
var modelConfig struct {
|
||||
EOSTokenID interface{} `json:"eos_token_id"`
|
||||
BOSTokenID interface{} `json:"bos_token_id"`
|
||||
}
|
||||
if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
|
||||
t.vocab.EOS = ids
|
||||
}
|
||||
}
|
||||
if t.vocab.BOS < 0 {
|
||||
if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
|
||||
t.vocab.BOS = ids[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: tokenizer_config.json
|
||||
if len(config.tokenizerConfigJSON) > 0 {
|
||||
var tokConfig struct {
|
||||
BOSToken interface{} `json:"bos_token"`
|
||||
EOSToken interface{} `json:"eos_token"`
|
||||
PADToken interface{} `json:"pad_token"`
|
||||
AddBOSToken *bool `json:"add_bos_token"`
|
||||
AddEOSToken *bool `json:"add_eos_token"`
|
||||
}
|
||||
if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if tokConfig.AddBOSToken != nil {
|
||||
t.vocab.AddBOS = *tokConfig.AddBOSToken
|
||||
}
|
||||
if tokConfig.AddEOSToken != nil {
|
||||
t.vocab.AddEOS = *tokConfig.AddEOSToken
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: special_tokens_map.json
|
||||
if len(config.specialTokensMapJSON) > 0 {
|
||||
var tokensMap map[string]interface{}
|
||||
if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
|
||||
if t.vocab.BOS < 0 {
|
||||
if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
|
||||
if id, ok := t.specialTokens[bosStr]; ok {
|
||||
t.vocab.BOS = id
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(t.vocab.EOS) == 0 {
|
||||
if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
|
||||
if id, ok := t.specialTokens[eosStr]; ok {
|
||||
t.vocab.EOS = []int32{id}
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.vocab.PAD < 0 {
|
||||
if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
|
||||
if id, ok := t.specialTokens[padStr]; ok {
|
||||
t.vocab.PAD = id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractTokenString extracts the token string from various formats used in HuggingFace configs.
|
||||
// Tokens can be represented as:
|
||||
// - string: "token"
|
||||
// - object: {"content": "token", ...}
|
||||
func extractTokenString(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
// Direct string
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
// Object with content field
|
||||
if m, ok := v.(map[string]interface{}); ok {
|
||||
if content, ok := m["content"].(string); ok {
|
||||
return content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
|
||||
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
|
||||
// - (?!\S) negative lookahead - RE2 doesn't support this
|
||||
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
|
||||
//
|
||||
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
|
||||
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
|
||||
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
|
||||
func rewritePatternForRE2(pattern string) string {
|
||||
// Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
|
||||
pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
|
||||
|
||||
// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
|
||||
// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
|
||||
pattern = strings.ReplaceAll(pattern,
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
|
||||
|
||||
// Expand case-insensitive contraction pattern to explicit alternations
|
||||
// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
|
||||
pattern = strings.ReplaceAll(pattern,
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
|
||||
`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
|
||||
|
||||
return pattern
|
||||
}
|
||||
|
||||
// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
|
||||
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
|
||||
applySpecialTokenConfig(t, specialTokenConfigData{
|
||||
tokenizerConfigJSON: config.TokenizerConfigJSON,
|
||||
generationConfigJSON: config.GenerationConfigJSON,
|
||||
specialTokensMapJSON: config.SpecialTokensMapJSON,
|
||||
configJSON: config.ConfigJSON,
|
||||
})
|
||||
}
|
||||
|
||||
// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
|
||||
// vs GPT-2 byte-level encoding
|
||||
func detectSentencePiece(data json.RawMessage) bool {
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Sequence decoder with Replace step (SentencePiece style)
|
||||
var seq struct {
|
||||
Type string `json:"type"`
|
||||
Decoders []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
String string `json:"String"`
|
||||
} `json:"pattern"`
|
||||
} `json:"decoders"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &seq); err == nil {
|
||||
if seq.Type == "Sequence" {
|
||||
for _, dec := range seq.Decoders {
|
||||
// Look for Replace decoder that converts ▁ to space
|
||||
if dec.Type == "Replace" && dec.Pattern.String == "▁" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for direct ByteLevel decoder (GPT-2 style)
|
||||
var simple struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &simple); err == nil {
|
||||
if simple.Type == "ByteLevel" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
|
||||
func initByteTokens(t *Tokenizer) {
|
||||
for i := range t.vocab.byteTokens {
|
||||
t.vocab.byteTokens[i] = -1
|
||||
}
|
||||
for b := 0; b < 256; b++ {
|
||||
token := fmt.Sprintf("<0x%02X>", b)
|
||||
if id, ok := t.vocab.Reverse[token]; ok {
|
||||
t.vocab.byteTokens[b] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
|
||||
func extractPretokenizer(data json.RawMessage) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to parse as a single Split pretokenizer
|
||||
var single struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
|
||||
return single.Pattern.Regex
|
||||
}
|
||||
|
||||
// Try to parse as Sequence of pretokenizers - use first Split pattern
|
||||
var seq struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
Regex string `json:"Regex"`
|
||||
} `json:"pattern"`
|
||||
} `json:"pretokenizers"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
|
||||
for _, pt := range seq.Pretokenizers {
|
||||
if pt.Type == "Split" && pt.Pattern.Regex != "" {
|
||||
return pt.Pattern.Regex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build mlx
|
||||
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"model": {
|
||||
"type": "WordPiece",
|
||||
"vocab": {"[UNK]": 0, "hello": 1}
|
||||
},
|
||||
"added_tokens": []
|
||||
}`)
|
||||
|
||||
_, err := LoadFromBytes(data)
|
||||
if err == nil {
|
||||
t.Fatal("expected WordPiece load to fail")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user