Compare commits

..

3 Commits

Author SHA1 Message Date
Bruce MacDonald
9d02d1d767 install: prevent partial download script execution (#14311)
Wrap script in main function so that a truncated partial download doesn't end up executing half a script.
2026-02-18 18:32:45 -08:00
Bruce MacDonald
1a636fb47a cmd: set codex env vars on launch and handle zstd request bodies (#14122)
The Codex runner was not setting OPENAI_BASE_URL or OPENAI_API_KEY, this prevents Codex from sending requests to api.openai.com instead of the local Ollama server. This mirrors the approach used by the Claude runner.

Codex v0.98.0 sends zstd-compressed request bodies to the /v1/responses endpoint. Add decompression support in ResponsesMiddleware with an 8MB max decompressed size limit to prevent resource exhaustion.
2026-02-18 17:19:36 -08:00
Patrick Devine
0759fface9 Revert "chore: update mlx-c bindings to 0.5.0 (#14303)" (#14316)
This reverts commit f01a9a7859.
2026-02-18 17:01:25 -08:00
18 changed files with 991 additions and 1269 deletions

View File

@@ -1 +1 @@
v0.5.0
v0.4.1

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
@@ -84,24 +83,3 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
// SignRequest adds a nonce query parameter and an Authorization header with
// an Ed25519 signature to req.
func SignRequest(ctx context.Context, req *http.Request) error {
nonce, err := NewNonce(rand.Reader, 16)
if err != nil {
return err
}
q := req.URL.Query()
q.Set("nonce", nonce)
req.URL.RawQuery = q.Encode()
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
signature, err := Sign(ctx, data)
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}

View File

@@ -1900,21 +1900,6 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
if version.HasCachedUpdate() {
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
_ = version.ClearCachedUpdate()
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if available, err := version.CheckForUpdate(ctx); err == nil && available {
_ = version.CacheAvailableUpdate()
}
}()
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
@@ -2332,18 +2317,6 @@ func NewCLI() *cobra.Command {
}
}
updateCmd := &cobra.Command{
Use: "update",
Short: "Update Ollama to the latest version",
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
force, _ := cmd.Flags().GetBool("force")
_ = version.ClearCachedUpdate()
return version.DoUpdate(force)
},
}
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
rootCmd.AddCommand(
serveCmd,
createCmd,
@@ -2361,7 +2334,6 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
updateCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)

View File

@@ -6,6 +6,7 @@ import (
"os/exec"
"strings"
"github.com/ollama/ollama/envconfig"
"golang.org/x/mod/semver"
)
@@ -32,6 +33,10 @@ func (c *Codex) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(),
"OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
"OPENAI_API_KEY=ollama",
)
return cmd.Run()
}

1
go.mod
View File

@@ -26,6 +26,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/klauspost/compress v1.18.3
github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

4
go.sum
View File

@@ -122,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -150,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
return
}
defer reader.Close()
c.Request.Body = io.NopCloser(reader)
c.Request.Header.Del("Content-Encoding")
}
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1238,3 +1239,102 @@ func TestImageEditsMiddleware(t *testing.T) {
})
}
}
func zstdCompress(t *testing.T, data []byte) []byte {
t.Helper()
var buf bytes.Buffer
w, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatal(err)
}
if _, err := w.Write(data); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
return buf.Bytes()
}
func TestResponsesMiddlewareZstd(t *testing.T) {
tests := []struct {
name string
body string
useZstd bool
oversized bool
wantCode int
wantModel string
wantMessage string
}{
{
name: "plain JSON",
body: `{"model": "test-model", "input": "Hello"}`,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd compressed",
body: `{"model": "test-model", "input": "Hello"}`,
useZstd: true,
wantCode: http.StatusOK,
wantModel: "test-model",
wantMessage: "Hello",
},
{
name: "zstd over max decompressed size",
oversized: true,
useZstd: true,
wantCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedRequest *api.ChatRequest
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
c.Status(http.StatusOK)
})
var bodyReader io.Reader
if tt.oversized {
bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
} else if tt.useZstd {
bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
} else {
bodyReader = strings.NewReader(tt.body)
}
req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
req.Header.Set("Content-Type", "application/json")
if tt.useZstd || tt.oversized {
req.Header.Set("Content-Encoding", "zstd")
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != tt.wantCode {
t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
}
if tt.wantCode != http.StatusOK {
return
}
if capturedRequest == nil {
t.Fatal("expected captured request, got nil")
}
if capturedRequest.Model != tt.wantModel {
t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
}
if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
}
})
}
}

View File

@@ -2,6 +2,10 @@
# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
# Wrap script in main function so that a truncated partial download doesn't end
# up executing half a script.
main() {
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
@@ -446,3 +450,6 @@ fi
status "NVIDIA GPU ready."
install_success
}
main

View File

@@ -1,190 +0,0 @@
package version
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
)
var updateCheckURLBase = "https://ollama.com"
// CheckForUpdate calls the ollama.com update API and reports whether a
// newer version is available.
func CheckForUpdate(ctx context.Context) (bool, error) {
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
if err != nil {
return false, fmt.Errorf("parse update URL: %w", err)
}
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", Version)
requestURL.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return false, fmt.Errorf("create request: %w", err)
}
_ = auth.SignRequest(ctx, req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("update check request: %w", err)
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK, nil
}
func cacheFilePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "update"), nil
}
// CacheAvailableUpdate creates the update marker file.
func CacheAvailableUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
return f.Close()
}
// HasCachedUpdate reports whether a non-stale update marker exists.
func HasCachedUpdate() bool {
path, err := cacheFilePath()
if err != nil {
return false
}
fi, err := os.Stat(path)
if err != nil {
return false
}
return time.Since(fi.ModTime()) <= 24*time.Hour
}
// ClearCachedUpdate removes the update marker file.
func ClearCachedUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
err = os.Remove(path)
if os.IsNotExist(err) {
return nil
}
return err
}
func IsOfficialInstall() bool {
exe, err := os.Executable()
if err != nil {
return false
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return false
}
switch runtime.GOOS {
case "windows":
localAppData := os.Getenv("LOCALAPPDATA")
if localAppData == "" {
return false
}
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
case "darwin":
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
default:
dir := filepath.Dir(exe)
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
}
}
// DoUpdate downloads and runs the platform-appropriate install script.
func DoUpdate(force bool) error {
if !force && !IsOfficialInstall() {
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
}
var scriptURL, tmpPattern, shell string
switch runtime.GOOS {
case "windows":
scriptURL = "https://ollama.com/install.ps1"
tmpPattern = "ollama-install-*.ps1"
shell = "powershell"
default:
scriptURL = "https://ollama.com/install.sh"
tmpPattern = "ollama-install-*.sh"
shell = "sh"
}
resp, err := http.Get(scriptURL)
if err != nil {
return fmt.Errorf("download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download install script: status %d", resp.StatusCode)
}
tmpFile, err := os.CreateTemp("", tmpPattern)
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
tmpFile.Close()
return fmt.Errorf("write install script: %w", err)
}
tmpFile.Close()
cmd := exec.Command(shell, tmpFile.Name())
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// IsLocalHost reports whether the configured Ollama host points to the
// local machine.
func IsLocalHost(host *url.URL) bool {
hostname := host.Hostname()
switch hostname {
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
return true
}
if ip := net.ParseIP(hostname); ip != nil {
return ip.IsLoopback()
}
return false
}

View File

@@ -1,146 +0,0 @@
package version
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"testing"
"time"
)
func setHome(t *testing.T, dir string) {
t.Helper()
if runtime.GOOS == "windows" {
t.Setenv("USERPROFILE", dir)
} else {
t.Setenv("HOME", dir)
}
}
func TestCheckForUpdate(t *testing.T) {
t.Run("update available", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
t.Error("missing expected query parameters")
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !available {
t.Fatal("expected update to be available")
}
})
t.Run("up to date", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if available {
t.Fatal("expected no update available")
}
})
t.Run("network error", func(t *testing.T) {
old := updateCheckURLBase
updateCheckURLBase = "http://localhost:1"
defer func() { updateCheckURLBase = old }()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := CheckForUpdate(ctx)
if err == nil {
t.Fatal("expected error for unreachable server")
}
})
}
func TestCacheRoundTrip(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
if !HasCachedUpdate() {
t.Fatal("expected cached update to be present")
}
if err := ClearCachedUpdate(); err != nil {
t.Fatalf("cache clear: %v", err)
}
if HasCachedUpdate() {
t.Fatal("expected no cached update after clear")
}
}
func TestHasCachedUpdateStale(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
// Backdate the file to make it stale
path := filepath.Join(tmp, ".ollama", "update")
staleTime := time.Now().Add(-25 * time.Hour)
os.Chtimes(path, staleTime, staleTime)
if HasCachedUpdate() {
t.Fatal("expected no cached update for stale file")
}
}
func TestIsLocalHost(t *testing.T) {
tests := []struct {
host string
local bool
}{
{"http://127.0.0.1:11434", true},
{"http://localhost:11434", true},
{"http://[::1]:11434", true},
{"http://0.0.0.0:11434", true},
{"http://remote.example.com:11434", false},
{"http://192.168.1.100:11434", false},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
u, err := url.Parse(tt.host)
if err != nil {
t.Fatalf("parse URL: %v", err)
}
if got := IsLocalHost(u); got != tt.local {
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
}
})
}
}

View File

@@ -16,10 +16,10 @@ import (
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
@@ -29,11 +29,6 @@ func findHeaders(directory string) ([]string, error) {
if err != nil {
return err
}
// Private headers contain C++ implementation helpers and are not part of
// the C API surface; parsing them can produce invalid wrapper signatures.
if d.IsDir() && d.Name() == "private" {
return fs.SkipDir
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
@@ -199,10 +194,10 @@ func parseFunctions(content string) []Function {
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}

View File

@@ -20,8 +20,6 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
@@ -51,7 +49,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
#endif
@@ -69,7 +67,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
#endif
@@ -125,7 +123,6 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
@@ -136,16 +133,6 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
@@ -276,6 +263,7 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
@@ -670,16 +658,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
@@ -1163,11 +1141,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
@@ -1218,56 +1191,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
@@ -1918,6 +1841,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
@@ -3600,14 +3528,6 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
return mlx_array_new_data_ptr(data, shape, dim, dtype);
}
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
}
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
}
int mlx_array_set(mlx_array* arr, const mlx_array src) {
return mlx_array_set_ptr(arr, src);
}
@@ -3724,7 +3644,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
@@ -3784,7 +3704,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
@@ -3996,10 +3916,6 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
return mlx_set_compile_mode_ptr(mode);
}
int mlx_cuda_is_available(bool* res) {
return mlx_cuda_is_available_ptr(res);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
@@ -4040,46 +3956,6 @@ int mlx_set_default_device(mlx_device dev) {
return mlx_set_default_device_ptr(dev);
}
int mlx_device_is_available(bool* avail, mlx_device dev) {
return mlx_device_is_available_ptr(avail, dev);
}
int mlx_device_count(int* count, mlx_device_type type) {
return mlx_device_count_ptr(count, type);
}
mlx_device_info mlx_device_info_new(void) {
return mlx_device_info_new_ptr();
}
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
return mlx_device_info_get_ptr(info, dev);
}
int mlx_device_info_free(mlx_device_info info) {
return mlx_device_info_free_ptr(info);
}
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
return mlx_device_info_has_key_ptr(exists, info, key);
}
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
return mlx_device_info_is_string_ptr(is_string, info, key);
}
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
return mlx_device_info_get_string_ptr(value, info, key);
}
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
return mlx_device_info_get_size_ptr(value, info, key);
}
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
return mlx_device_info_get_keys_ptr(keys, info);
}
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
return mlx_distributed_all_gather_ptr(res, x, group, S);
}
@@ -4600,6 +4476,10 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
return mlx_set_wired_limit_ptr(res, limit);
}
mlx_metal_device_info_t mlx_metal_device_info(void) {
return mlx_metal_device_info_ptr();
}
int mlx_metal_is_available(bool* res) {
return mlx_metal_is_available_ptr(res);
}

View File

@@ -26,8 +26,6 @@
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_new_data_managed
#undef mlx_array_new_data_managed_payload
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
@@ -123,7 +121,6 @@
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_cuda_is_available
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
@@ -134,16 +131,6 @@
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_device_is_available
#undef mlx_device_count
#undef mlx_device_info_new
#undef mlx_device_info_get
#undef mlx_device_info_free
#undef mlx_device_info_has_key
#undef mlx_device_info_is_string
#undef mlx_device_info_get_string
#undef mlx_device_info_get_size
#undef mlx_device_info_get_keys
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
@@ -274,6 +261,7 @@
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
@@ -614,8 +602,6 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
extern mlx_array (*mlx_array_new_double_ptr)(double val);
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
@@ -645,7 +631,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
#endif
@@ -663,7 +649,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
#endif
@@ -719,7 +705,6 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
extern int (*mlx_disable_compile_ptr)(void);
extern int (*mlx_enable_compile_ptr)(void);
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
extern int (*mlx_cuda_is_available_ptr)(bool* res);
extern mlx_device (*mlx_device_new_ptr)(void);
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
extern int (*mlx_device_free_ptr)(mlx_device dev);
@@ -730,16 +715,6 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -870,6 +845,7 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
extern int (*mlx_metal_is_available_ptr)(bool* res);
extern int (*mlx_metal_start_capture_ptr)(const char* path);
extern int (*mlx_metal_stop_capture_ptr)(void);
@@ -1226,10 +1202,6 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
int mlx_array_set(mlx_array* arr, const mlx_array src);
int mlx_array_set_bool(mlx_array* arr, bool val);
@@ -1288,7 +1260,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
int mlx_array_item_float64(double* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
@@ -1320,7 +1292,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
const double* mlx_array_data_float64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr);
@@ -1428,8 +1400,6 @@ int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
int mlx_cuda_is_available(bool* res);
mlx_device mlx_device_new(void);
mlx_device mlx_device_new_type(mlx_device_type type, int index);
@@ -1450,26 +1420,6 @@ int mlx_get_default_device(mlx_device* dev);
int mlx_set_default_device(mlx_device dev);
int mlx_device_is_available(bool* avail, mlx_device dev);
int mlx_device_count(int* count, mlx_device_type type);
mlx_device_info mlx_device_info_new(void);
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
int mlx_device_info_free(mlx_device_info info);
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -1730,6 +1680,8 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
mlx_metal_device_info_t mlx_metal_device_info(void);
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);

View File

@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c

View File

@@ -22,19 +22,6 @@ mlx_array (*mlx_array_new_data_)(
const int* shape,
int dim,
mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*)) = NULL;
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
@@ -69,7 +56,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
@@ -83,7 +70,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
@@ -107,11 +94,10 @@ int (*mlx_closure_apply_)(
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -150,12 +136,11 @@ int (*mlx_closure_value_and_grad_apply_)(
const mlx_vector_array input) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -176,13 +161,12 @@ int (*mlx_closure_custom_apply_)(
const mlx_vector_array input_2) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -205,13 +189,12 @@ int (*mlx_closure_custom_jvp_apply_)(
size_t input_2_num) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -245,7 +228,6 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_)(void) = NULL;
int (*mlx_enable_compile_)(void) = NULL;
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_)(bool* res) = NULL;
mlx_device (*mlx_device_new_)(void) = NULL;
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_)(mlx_device dev) = NULL;
@@ -256,28 +238,11 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_)(
bool* exists,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_is_string_)(
bool* is_string,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_string_)(
const char** value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_size_)(
size_t* value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
@@ -323,11 +288,6 @@ int (*mlx_distributed_sum_scatter_)(
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -490,16 +450,6 @@ int (*mlx_fast_rope_)(
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
@@ -610,6 +560,14 @@ int (*mlx_fft_rfftn_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
@@ -635,14 +593,6 @@ int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
@@ -783,6 +733,7 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
int (*mlx_metal_is_available_)(bool* res) = NULL;
int (*mlx_metal_start_capture_)(const char* path) = NULL;
int (*mlx_metal_stop_capture_)(void) = NULL;
@@ -1211,14 +1162,6 @@ int (*mlx_gather_)(
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1540,15 +1483,6 @@ int (*mlx_put_along_axis_)(
const mlx_array values,
int axis,
const mlx_stream s) = NULL;
int (*mlx_qqmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array w_scales /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
@@ -1632,13 +1566,6 @@ int (*mlx_scatter_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
@@ -1647,13 +1574,6 @@ int (*mlx_scatter_add_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
@@ -1669,13 +1589,6 @@ int (*mlx_scatter_max_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
@@ -1684,13 +1597,6 @@ int (*mlx_scatter_min_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
@@ -1699,13 +1605,6 @@ int (*mlx_scatter_prod_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2129,6 +2028,22 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
const char * (*mlx_string_data_)(mlx_string str) = NULL;
int (*mlx_string_free_)(mlx_string str) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_)(
@@ -2159,22 +2074,6 @@ int (*mlx_vjp_)(
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
@@ -2267,8 +2166,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_new_double);
CHECK_LOAD(handle, mlx_array_new_complex);
CHECK_LOAD(handle, mlx_array_new_data);
CHECK_LOAD(handle, mlx_array_new_data_managed);
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
CHECK_LOAD(handle, mlx_array_set);
CHECK_LOAD(handle, mlx_array_set_bool);
CHECK_LOAD(handle, mlx_array_set_int);
@@ -2364,7 +2261,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_disable_compile);
CHECK_LOAD(handle, mlx_enable_compile);
CHECK_LOAD(handle, mlx_set_compile_mode);
CHECK_LOAD(handle, mlx_cuda_is_available);
CHECK_LOAD(handle, mlx_device_new);
CHECK_LOAD(handle, mlx_device_new_type);
CHECK_LOAD(handle, mlx_device_free);
@@ -2375,16 +2271,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_device_get_type);
CHECK_LOAD(handle, mlx_get_default_device);
CHECK_LOAD(handle, mlx_set_default_device);
CHECK_LOAD(handle, mlx_device_is_available);
CHECK_LOAD(handle, mlx_device_count);
CHECK_LOAD(handle, mlx_device_info_new);
CHECK_LOAD(handle, mlx_device_info_get);
CHECK_LOAD(handle, mlx_device_info_free);
CHECK_LOAD(handle, mlx_device_info_has_key);
CHECK_LOAD(handle, mlx_device_info_is_string);
CHECK_LOAD(handle, mlx_device_info_get_string);
CHECK_LOAD(handle, mlx_device_info_get_size);
CHECK_LOAD(handle, mlx_device_info_get_keys);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_distributed_all_gather);
CHECK_LOAD(handle, mlx_distributed_all_max);
CHECK_LOAD(handle, mlx_distributed_all_min);
@@ -2393,11 +2284,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_distributed_recv_like);
CHECK_LOAD(handle, mlx_distributed_send);
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_set_error_handler);
CHECK_LOAD(handle, _mlx_error);
CHECK_LOAD(handle, mlx_export_function);
@@ -2439,7 +2325,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
CHECK_LOAD(handle, mlx_fast_rms_norm);
CHECK_LOAD(handle, mlx_fast_rope);
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
CHECK_LOAD(handle, mlx_fft_fft);
CHECK_LOAD(handle, mlx_fft_fft2);
@@ -2455,14 +2340,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fft_rfft);
CHECK_LOAD(handle, mlx_fft_rfft2);
CHECK_LOAD(handle, mlx_fft_rfftn);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
@@ -2471,6 +2348,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_linalg_cholesky);
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
CHECK_LOAD(handle, mlx_linalg_cross);
@@ -2515,6 +2400,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_set_cache_limit);
CHECK_LOAD(handle, mlx_set_memory_limit);
CHECK_LOAD(handle, mlx_set_wired_limit);
CHECK_LOAD(handle, mlx_metal_device_info);
CHECK_LOAD(handle, mlx_metal_is_available);
CHECK_LOAD(handle, mlx_metal_start_capture);
CHECK_LOAD(handle, mlx_metal_stop_capture);
@@ -2600,7 +2486,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_full);
CHECK_LOAD(handle, mlx_full_like);
CHECK_LOAD(handle, mlx_gather);
CHECK_LOAD(handle, mlx_gather_single);
CHECK_LOAD(handle, mlx_gather_mm);
CHECK_LOAD(handle, mlx_gather_qmm);
CHECK_LOAD(handle, mlx_greater);
@@ -2665,7 +2550,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_prod_axis);
CHECK_LOAD(handle, mlx_prod);
CHECK_LOAD(handle, mlx_put_along_axis);
CHECK_LOAD(handle, mlx_qqmm);
CHECK_LOAD(handle, mlx_quantize);
CHECK_LOAD(handle, mlx_quantized_matmul);
CHECK_LOAD(handle, mlx_radians);
@@ -2682,16 +2566,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_round);
CHECK_LOAD(handle, mlx_rsqrt);
CHECK_LOAD(handle, mlx_scatter);
CHECK_LOAD(handle, mlx_scatter_single);
CHECK_LOAD(handle, mlx_scatter_add);
CHECK_LOAD(handle, mlx_scatter_add_single);
CHECK_LOAD(handle, mlx_scatter_add_axis);
CHECK_LOAD(handle, mlx_scatter_max);
CHECK_LOAD(handle, mlx_scatter_max_single);
CHECK_LOAD(handle, mlx_scatter_min);
CHECK_LOAD(handle, mlx_scatter_min_single);
CHECK_LOAD(handle, mlx_scatter_prod);
CHECK_LOAD(handle, mlx_scatter_prod_single);
CHECK_LOAD(handle, mlx_segmented_mm);
CHECK_LOAD(handle, mlx_sigmoid);
CHECK_LOAD(handle, mlx_sign);
@@ -2786,6 +2665,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_string_set);
CHECK_LOAD(handle, mlx_string_data);
CHECK_LOAD(handle, mlx_string_free);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_async_eval);
CHECK_LOAD(handle, mlx_checkpoint);
CHECK_LOAD(handle, mlx_custom_function);
@@ -2794,8 +2675,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_jvp);
CHECK_LOAD(handle, mlx_value_and_grad);
CHECK_LOAD(handle, mlx_vjp);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_vector_array_new);
CHECK_LOAD(handle, mlx_vector_array_set);
CHECK_LOAD(handle, mlx_vector_array_free);

View File

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,6 @@
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}