mirror of
https://github.com/ollama/ollama.git
synced 2026-02-19 07:45:22 -05:00
Compare commits
2 Commits
main
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
365a3657ad | ||
|
|
71c1d8d0a9 |
@@ -1 +1 @@
|
||||
v0.4.1
|
||||
v0.5.0
|
||||
|
||||
22
auth/auth.go
22
auth/auth.go
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -83,3 +84,24 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
// signature is <pubkey>:<signature>
|
||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||
}
|
||||
|
||||
// SignRequest adds a nonce query parameter and an Authorization header with
|
||||
// an Ed25519 signature to req.
|
||||
func SignRequest(ctx context.Context, req *http.Request) error {
|
||||
nonce, err := NewNonce(rand.Reader, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Set("nonce", nonce)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
|
||||
signature, err := Sign(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
28
cmd/cmd.go
28
cmd/cmd.go
@@ -1900,6 +1900,21 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
||||
return
|
||||
}
|
||||
|
||||
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
|
||||
if version.HasCachedUpdate() {
|
||||
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
|
||||
_ = version.ClearCachedUpdate()
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if available, err := version.CheckForUpdate(ctx); err == nil && available {
|
||||
_ = version.CacheAvailableUpdate()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Selector adapters for tui
|
||||
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||
@@ -2317,6 +2332,18 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
updateCmd := &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update Ollama to the latest version",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
force, _ := cmd.Flags().GetBool("force")
|
||||
_ = version.ClearCachedUpdate()
|
||||
return version.DoUpdate(force)
|
||||
},
|
||||
}
|
||||
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
|
||||
|
||||
rootCmd.AddCommand(
|
||||
serveCmd,
|
||||
createCmd,
|
||||
@@ -2334,6 +2361,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
runnerCmd,
|
||||
updateCmd,
|
||||
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"golang.org/x/mod/semver"
|
||||
)
|
||||
|
||||
@@ -33,10 +32,6 @@ func (c *Codex) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = append(os.Environ(),
|
||||
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
|
||||
"OPENAI_API_KEY=ollama",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
|
||||
1
go.mod
1
go.mod
@@ -26,7 +26,6 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/klauspost/compress v1.18.3
|
||||
github.com/mattn/go-runewidth v0.0.16
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
4
go.sum
4
go.sum
@@ -122,6 +122,7 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
|
||||
@@ -149,9 +150,8 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
|
||||
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
|
||||
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
|
||||
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
|
||||
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -497,17 +496,6 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
|
||||
|
||||
func ResponsesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "zstd" {
|
||||
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
var req openai.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
@@ -1239,102 +1238,3 @@ func TestImageEditsMiddleware(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func zstdCompress(t *testing.T, data []byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
w, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestResponsesMiddlewareZstd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
useZstd bool
|
||||
oversized bool
|
||||
wantCode int
|
||||
wantModel string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "plain JSON",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd compressed",
|
||||
body: `{"model": "test-model", "input": "Hello"}`,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusOK,
|
||||
wantModel: "test-model",
|
||||
wantMessage: "Hello",
|
||||
},
|
||||
{
|
||||
name: "zstd over max decompressed size",
|
||||
oversized: true,
|
||||
useZstd: true,
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
var bodyReader io.Reader
|
||||
if tt.oversized {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
|
||||
} else if tt.useZstd {
|
||||
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
|
||||
} else {
|
||||
bodyReader = strings.NewReader(tt.body)
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.useZstd || tt.oversized {
|
||||
req.Header.Set("Content-Encoding", "zstd")
|
||||
}
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != tt.wantCode {
|
||||
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
if capturedRequest == nil {
|
||||
t.Fatal("expected captured request, got nil")
|
||||
}
|
||||
if capturedRequest.Model != tt.wantModel {
|
||||
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
|
||||
}
|
||||
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
|
||||
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
# This script installs Ollama on Linux and macOS.
|
||||
# It detects the current operating system architecture and installs the appropriate version of Ollama.
|
||||
|
||||
# Wrap script in main function so that a truncated partial download doesn't end
|
||||
# up executing half a script.
|
||||
main() {
|
||||
|
||||
set -eu
|
||||
|
||||
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
|
||||
@@ -450,6 +446,3 @@ fi
|
||||
|
||||
status "NVIDIA GPU ready."
|
||||
install_success
|
||||
}
|
||||
|
||||
main
|
||||
|
||||
190
version/update.go
Normal file
190
version/update.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
)
|
||||
|
||||
var updateCheckURLBase = "https://ollama.com"
|
||||
|
||||
// CheckForUpdate calls the ollama.com update API and reports whether a
|
||||
// newer version is available.
|
||||
func CheckForUpdate(ctx context.Context) (bool, error) {
|
||||
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse update URL: %w", err)
|
||||
}
|
||||
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", Version)
|
||||
requestURL.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
_ = auth.SignRequest(ctx, req)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("update check request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK, nil
|
||||
}
|
||||
|
||||
func cacheFilePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "update"), nil
|
||||
}
|
||||
|
||||
// CacheAvailableUpdate creates the update marker file.
|
||||
func CacheAvailableUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
|
||||
// HasCachedUpdate reports whether a non-stale update marker exists.
|
||||
func HasCachedUpdate() bool {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(fi.ModTime()) <= 24*time.Hour
|
||||
}
|
||||
|
||||
// ClearCachedUpdate removes the update marker file.
|
||||
func ClearCachedUpdate() error {
|
||||
path, err := cacheFilePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.Remove(path)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func IsOfficialInstall() bool {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
exe, err = filepath.EvalSymlinks(exe)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
if localAppData == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
|
||||
case "darwin":
|
||||
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
|
||||
default:
|
||||
dir := filepath.Dir(exe)
|
||||
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
|
||||
}
|
||||
}
|
||||
|
||||
// DoUpdate downloads and runs the platform-appropriate install script.
|
||||
func DoUpdate(force bool) error {
|
||||
if !force && !IsOfficialInstall() {
|
||||
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
|
||||
}
|
||||
|
||||
var scriptURL, tmpPattern, shell string
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
scriptURL = "https://ollama.com/install.ps1"
|
||||
tmpPattern = "ollama-install-*.ps1"
|
||||
shell = "powershell"
|
||||
default:
|
||||
scriptURL = "https://ollama.com/install.sh"
|
||||
tmpPattern = "ollama-install-*.sh"
|
||||
shell = "sh"
|
||||
}
|
||||
|
||||
resp, err := http.Get(scriptURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download install script: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download install script: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp("", tmpPattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
|
||||
tmpFile.Close()
|
||||
return fmt.Errorf("write install script: %w", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
cmd := exec.Command(shell, tmpFile.Name())
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// IsLocalHost reports whether the configured Ollama host points to the
|
||||
// local machine.
|
||||
func IsLocalHost(host *url.URL) bool {
|
||||
hostname := host.Hostname()
|
||||
switch hostname {
|
||||
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
146
version/update_test.go
Normal file
146
version/update_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func setHome(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Setenv("USERPROFILE", dir)
|
||||
} else {
|
||||
t.Setenv("HOME", dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForUpdate(t *testing.T) {
|
||||
t.Run("update available", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
|
||||
t.Error("missing expected query parameters")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !available {
|
||||
t.Fatal("expected update to be available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("up to date", func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = ts.URL
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
available, err := CheckForUpdate(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if available {
|
||||
t.Fatal("expected no update available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("network error", func(t *testing.T) {
|
||||
old := updateCheckURLBase
|
||||
updateCheckURLBase = "http://localhost:1"
|
||||
defer func() { updateCheckURLBase = old }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := CheckForUpdate(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unreachable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheRoundTrip(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
if !HasCachedUpdate() {
|
||||
t.Fatal("expected cached update to be present")
|
||||
}
|
||||
|
||||
if err := ClearCachedUpdate(); err != nil {
|
||||
t.Fatalf("cache clear: %v", err)
|
||||
}
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCachedUpdateStale(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
setHome(t, tmp)
|
||||
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
|
||||
|
||||
if err := CacheAvailableUpdate(); err != nil {
|
||||
t.Fatalf("cache write: %v", err)
|
||||
}
|
||||
|
||||
// Backdate the file to make it stale
|
||||
path := filepath.Join(tmp, ".ollama", "update")
|
||||
staleTime := time.Now().Add(-25 * time.Hour)
|
||||
os.Chtimes(path, staleTime, staleTime)
|
||||
|
||||
if HasCachedUpdate() {
|
||||
t.Fatal("expected no cached update for stale file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
host string
|
||||
local bool
|
||||
}{
|
||||
{"http://127.0.0.1:11434", true},
|
||||
{"http://localhost:11434", true},
|
||||
{"http://[::1]:11434", true},
|
||||
{"http://0.0.0.0:11434", true},
|
||||
{"http://remote.example.com:11434", false},
|
||||
{"http://192.168.1.100:11434", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.host, func(t *testing.T) {
|
||||
u, err := url.Parse(tt.host)
|
||||
if err != nil {
|
||||
t.Fatalf("parse URL: %v", err)
|
||||
}
|
||||
if got := IsLocalHost(u); got != tt.local {
|
||||
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
@@ -29,6 +29,11 @@ func findHeaders(directory string) ([]string, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Private headers contain C++ implementation helpers and are not part of
|
||||
// the C API surface; parsing them can produce invalid wrapper signatures.
|
||||
if d.IsDir() && d.Name() == "private" {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
@@ -194,10 +199,10 @@ func parseFunctions(content string) []Function {
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,8 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
||||
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
||||
@@ -49,7 +51,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -67,7 +69,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
||||
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
||||
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
||||
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
||||
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -123,6 +125,7 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_enable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
||||
@@ -133,6 +136,16 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
||||
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
||||
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
@@ -263,7 +276,6 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
|
||||
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
||||
@@ -658,6 +670,16 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
|
||||
if (mlx_array_new_data_managed_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
|
||||
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
|
||||
if (mlx_array_set_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
||||
@@ -1141,6 +1163,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
|
||||
if (mlx_cuda_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
|
||||
if (mlx_device_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
||||
@@ -1191,6 +1218,56 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
|
||||
if (mlx_device_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
|
||||
if (mlx_device_count_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
|
||||
if (mlx_device_info_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
|
||||
if (mlx_device_info_get_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
|
||||
if (mlx_device_info_free_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
|
||||
if (mlx_device_info_has_key_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
|
||||
if (mlx_device_info_is_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
|
||||
if (mlx_device_info_get_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
|
||||
if (mlx_device_info_get_size_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
|
||||
if (mlx_device_info_get_keys_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
|
||||
if (mlx_distributed_all_gather_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
||||
@@ -1841,11 +1918,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
|
||||
if (mlx_metal_device_info_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
|
||||
if (mlx_metal_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
||||
@@ -3528,6 +3600,14 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
|
||||
return mlx_array_new_data_ptr(data, shape, dim, dtype);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
|
||||
}
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src) {
|
||||
return mlx_array_set_ptr(arr, src);
|
||||
}
|
||||
@@ -3644,7 +3724,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
|
||||
return mlx_array_item_float64_ptr(res, arr);
|
||||
}
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
||||
return mlx_array_item_complex64_ptr(res, arr);
|
||||
}
|
||||
|
||||
@@ -3704,7 +3784,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
|
||||
return mlx_array_data_float64_ptr(arr);
|
||||
}
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
|
||||
return mlx_array_data_complex64_ptr(arr);
|
||||
}
|
||||
|
||||
@@ -3916,6 +3996,10 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
|
||||
return mlx_set_compile_mode_ptr(mode);
|
||||
}
|
||||
|
||||
int mlx_cuda_is_available(bool* res) {
|
||||
return mlx_cuda_is_available_ptr(res);
|
||||
}
|
||||
|
||||
mlx_device mlx_device_new(void) {
|
||||
return mlx_device_new_ptr();
|
||||
}
|
||||
@@ -3956,6 +4040,46 @@ int mlx_set_default_device(mlx_device dev) {
|
||||
return mlx_set_default_device_ptr(dev);
|
||||
}
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
||||
return mlx_device_is_available_ptr(avail, dev);
|
||||
}
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type) {
|
||||
return mlx_device_count_ptr(count, type);
|
||||
}
|
||||
|
||||
mlx_device_info mlx_device_info_new(void) {
|
||||
return mlx_device_info_new_ptr();
|
||||
}
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
||||
return mlx_device_info_get_ptr(info, dev);
|
||||
}
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info) {
|
||||
return mlx_device_info_free_ptr(info);
|
||||
}
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_has_key_ptr(exists, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_is_string_ptr(is_string, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_string_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_size_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
||||
return mlx_device_info_get_keys_ptr(keys, info);
|
||||
}
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
||||
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
||||
}
|
||||
@@ -4476,10 +4600,6 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
|
||||
return mlx_set_wired_limit_ptr(res, limit);
|
||||
}
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void) {
|
||||
return mlx_metal_device_info_ptr();
|
||||
}
|
||||
|
||||
int mlx_metal_is_available(bool* res) {
|
||||
return mlx_metal_is_available_ptr(res);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@
|
||||
#undef mlx_array_new_double
|
||||
#undef mlx_array_new_complex
|
||||
#undef mlx_array_new_data
|
||||
#undef mlx_array_new_data_managed
|
||||
#undef mlx_array_new_data_managed_payload
|
||||
#undef mlx_array_set
|
||||
#undef mlx_array_set_bool
|
||||
#undef mlx_array_set_int
|
||||
@@ -121,6 +123,7 @@
|
||||
#undef mlx_disable_compile
|
||||
#undef mlx_enable_compile
|
||||
#undef mlx_set_compile_mode
|
||||
#undef mlx_cuda_is_available
|
||||
#undef mlx_device_new
|
||||
#undef mlx_device_new_type
|
||||
#undef mlx_device_free
|
||||
@@ -131,6 +134,16 @@
|
||||
#undef mlx_device_get_type
|
||||
#undef mlx_get_default_device
|
||||
#undef mlx_set_default_device
|
||||
#undef mlx_device_is_available
|
||||
#undef mlx_device_count
|
||||
#undef mlx_device_info_new
|
||||
#undef mlx_device_info_get
|
||||
#undef mlx_device_info_free
|
||||
#undef mlx_device_info_has_key
|
||||
#undef mlx_device_info_is_string
|
||||
#undef mlx_device_info_get_string
|
||||
#undef mlx_device_info_get_size
|
||||
#undef mlx_device_info_get_keys
|
||||
#undef mlx_distributed_all_gather
|
||||
#undef mlx_distributed_all_max
|
||||
#undef mlx_distributed_all_min
|
||||
@@ -261,7 +274,6 @@
|
||||
#undef mlx_set_cache_limit
|
||||
#undef mlx_set_memory_limit
|
||||
#undef mlx_set_wired_limit
|
||||
#undef mlx_metal_device_info
|
||||
#undef mlx_metal_is_available
|
||||
#undef mlx_metal_start_capture
|
||||
#undef mlx_metal_stop_capture
|
||||
@@ -602,6 +614,8 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_double_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
|
||||
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
|
||||
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
|
||||
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
|
||||
@@ -631,7 +645,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
@@ -649,7 +663,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
|
||||
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
|
||||
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
|
||||
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
|
||||
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
|
||||
#endif
|
||||
@@ -705,6 +719,7 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
|
||||
extern int (*mlx_disable_compile_ptr)(void);
|
||||
extern int (*mlx_enable_compile_ptr)(void);
|
||||
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
|
||||
extern int (*mlx_cuda_is_available_ptr)(bool* res);
|
||||
extern mlx_device (*mlx_device_new_ptr)(void);
|
||||
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
|
||||
extern int (*mlx_device_free_ptr)(mlx_device dev);
|
||||
@@ -715,6 +730,16 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
|
||||
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
|
||||
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
|
||||
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
|
||||
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
|
||||
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
|
||||
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
|
||||
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
|
||||
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
|
||||
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
|
||||
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -845,7 +870,6 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
|
||||
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
|
||||
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
|
||||
extern int (*mlx_metal_is_available_ptr)(bool* res);
|
||||
extern int (*mlx_metal_start_capture_ptr)(const char* path);
|
||||
extern int (*mlx_metal_stop_capture_ptr)(void);
|
||||
@@ -1202,6 +1226,10 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
|
||||
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
@@ -1260,7 +1288,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
@@ -1292,7 +1320,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
|
||||
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
@@ -1400,6 +1428,8 @@ int mlx_enable_compile(void);
|
||||
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
@@ -1420,6 +1450,26 @@ int mlx_get_default_device(mlx_device* dev);
|
||||
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
|
||||
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -1680,8 +1730,6 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void);
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
|
||||
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -22,6 +22,19 @@ mlx_array (*mlx_array_new_data_)(
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
|
||||
@@ -56,7 +69,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
|
||||
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
|
||||
@@ -70,7 +83,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
|
||||
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
|
||||
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
|
||||
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
|
||||
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
|
||||
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
|
||||
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
|
||||
@@ -94,10 +107,11 @@ int (*mlx_closure_apply_)(
|
||||
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
|
||||
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -136,11 +150,12 @@ int (*mlx_closure_value_and_grad_apply_)(
|
||||
const mlx_vector_array input) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -161,12 +176,13 @@ int (*mlx_closure_custom_apply_)(
|
||||
const mlx_vector_array input_2) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -189,12 +205,13 @@ int (*mlx_closure_custom_jvp_apply_)(
|
||||
size_t input_2_num) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -228,6 +245,7 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_)(void) = NULL;
|
||||
int (*mlx_enable_compile_)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_)(mlx_device dev) = NULL;
|
||||
@@ -238,11 +256,28 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
|
||||
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_)(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_)(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_)(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_)(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
@@ -288,6 +323,11 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -450,6 +490,16 @@ int (*mlx_fast_rope_)(
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_rope_dynamic_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_scaled_dot_product_attention_)(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
@@ -560,14 +610,6 @@ int (*mlx_fft_rfftn_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_load_reader_)(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
@@ -593,6 +635,14 @@ int (*mlx_save_safetensors_)(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_linalg_cholesky_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -733,7 +783,6 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
|
||||
int (*mlx_metal_is_available_)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_)(void) = NULL;
|
||||
@@ -1162,6 +1211,14 @@ int (*mlx_gather_)(
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
int axis,
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1483,6 +1540,15 @@ int (*mlx_put_along_axis_)(
|
||||
const mlx_array values,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array w,
|
||||
const mlx_array w_scales /* may be null */,
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array w,
|
||||
@@ -1566,6 +1632,13 @@ int (*mlx_scatter_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1574,6 +1647,13 @@ int (*mlx_scatter_add_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_axis_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1589,6 +1669,13 @@ int (*mlx_scatter_max_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_max_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1597,6 +1684,13 @@ int (*mlx_scatter_min_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1605,6 +1699,13 @@ int (*mlx_scatter_prod_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_segmented_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2028,22 +2129,6 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
|
||||
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
|
||||
const char * (*mlx_string_data_)(mlx_string str) = NULL;
|
||||
int (*mlx_string_free_)(mlx_string str) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
|
||||
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
|
||||
int (*mlx_custom_function_)(
|
||||
@@ -2074,6 +2159,22 @@ int (*mlx_vjp_)(
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
|
||||
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
||||
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
|
||||
@@ -2166,6 +2267,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_new_double);
|
||||
CHECK_LOAD(handle, mlx_array_new_complex);
|
||||
CHECK_LOAD(handle, mlx_array_new_data);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
|
||||
CHECK_LOAD(handle, mlx_array_set);
|
||||
CHECK_LOAD(handle, mlx_array_set_bool);
|
||||
CHECK_LOAD(handle, mlx_array_set_int);
|
||||
@@ -2261,6 +2364,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_disable_compile);
|
||||
CHECK_LOAD(handle, mlx_enable_compile);
|
||||
CHECK_LOAD(handle, mlx_set_compile_mode);
|
||||
CHECK_LOAD(handle, mlx_cuda_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_new);
|
||||
CHECK_LOAD(handle, mlx_device_new_type);
|
||||
CHECK_LOAD(handle, mlx_device_free);
|
||||
@@ -2271,11 +2375,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_device_get_type);
|
||||
CHECK_LOAD(handle, mlx_get_default_device);
|
||||
CHECK_LOAD(handle, mlx_set_default_device);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_device_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_count);
|
||||
CHECK_LOAD(handle, mlx_device_info_new);
|
||||
CHECK_LOAD(handle, mlx_device_info_get);
|
||||
CHECK_LOAD(handle, mlx_device_info_free);
|
||||
CHECK_LOAD(handle, mlx_device_info_has_key);
|
||||
CHECK_LOAD(handle, mlx_device_info_is_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_size);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_keys);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_gather);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_max);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_min);
|
||||
@@ -2284,6 +2393,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_distributed_recv_like);
|
||||
CHECK_LOAD(handle, mlx_distributed_send);
|
||||
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_set_error_handler);
|
||||
CHECK_LOAD(handle, _mlx_error);
|
||||
CHECK_LOAD(handle, mlx_export_function);
|
||||
@@ -2325,6 +2439,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
|
||||
CHECK_LOAD(handle, mlx_fast_rms_norm);
|
||||
CHECK_LOAD(handle, mlx_fast_rope);
|
||||
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
|
||||
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
|
||||
CHECK_LOAD(handle, mlx_fft_fft);
|
||||
CHECK_LOAD(handle, mlx_fft_fft2);
|
||||
@@ -2340,14 +2455,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fft_rfft);
|
||||
CHECK_LOAD(handle, mlx_fft_rfft2);
|
||||
CHECK_LOAD(handle, mlx_fft_rfftn);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
@@ -2356,6 +2463,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_reader_free);
|
||||
CHECK_LOAD(handle, mlx_io_writer_new);
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
|
||||
CHECK_LOAD(handle, mlx_linalg_cross);
|
||||
@@ -2400,7 +2515,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_set_cache_limit);
|
||||
CHECK_LOAD(handle, mlx_set_memory_limit);
|
||||
CHECK_LOAD(handle, mlx_set_wired_limit);
|
||||
CHECK_LOAD(handle, mlx_metal_device_info);
|
||||
CHECK_LOAD(handle, mlx_metal_is_available);
|
||||
CHECK_LOAD(handle, mlx_metal_start_capture);
|
||||
CHECK_LOAD(handle, mlx_metal_stop_capture);
|
||||
@@ -2486,6 +2600,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_full);
|
||||
CHECK_LOAD(handle, mlx_full_like);
|
||||
CHECK_LOAD(handle, mlx_gather);
|
||||
CHECK_LOAD(handle, mlx_gather_single);
|
||||
CHECK_LOAD(handle, mlx_gather_mm);
|
||||
CHECK_LOAD(handle, mlx_gather_qmm);
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
@@ -2550,6 +2665,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_prod_axis);
|
||||
CHECK_LOAD(handle, mlx_prod);
|
||||
CHECK_LOAD(handle, mlx_put_along_axis);
|
||||
CHECK_LOAD(handle, mlx_qqmm);
|
||||
CHECK_LOAD(handle, mlx_quantize);
|
||||
CHECK_LOAD(handle, mlx_quantized_matmul);
|
||||
CHECK_LOAD(handle, mlx_radians);
|
||||
@@ -2566,11 +2682,16 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_round);
|
||||
CHECK_LOAD(handle, mlx_rsqrt);
|
||||
CHECK_LOAD(handle, mlx_scatter);
|
||||
CHECK_LOAD(handle, mlx_scatter_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_axis);
|
||||
CHECK_LOAD(handle, mlx_scatter_max);
|
||||
CHECK_LOAD(handle, mlx_scatter_max_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_min);
|
||||
CHECK_LOAD(handle, mlx_scatter_min_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod_single);
|
||||
CHECK_LOAD(handle, mlx_segmented_mm);
|
||||
CHECK_LOAD(handle, mlx_sigmoid);
|
||||
CHECK_LOAD(handle, mlx_sign);
|
||||
@@ -2665,8 +2786,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_string_set);
|
||||
CHECK_LOAD(handle, mlx_string_data);
|
||||
CHECK_LOAD(handle, mlx_string_free);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_async_eval);
|
||||
CHECK_LOAD(handle, mlx_checkpoint);
|
||||
CHECK_LOAD(handle, mlx_custom_function);
|
||||
@@ -2675,6 +2794,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_jvp);
|
||||
CHECK_LOAD(handle, mlx_value_and_grad);
|
||||
CHECK_LOAD(handle, mlx_vjp);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_vector_array_new);
|
||||
CHECK_LOAD(handle, mlx_vector_array_set);
|
||||
CHECK_LOAD(handle, mlx_vector_array_free);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,10 @@
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
|
||||
Reference in New Issue
Block a user