Compare commits

..

3 Commits

Author SHA1 Message Date
Bruce MacDonald
9ac1300805 fix lint 2026-02-17 16:04:10 -07:00
Bruce MacDonald
43d9907dd6 fix tests 2026-02-17 16:04:10 -07:00
Bruce MacDonald
91dc088e8b server: usage api
Add a new /api/usage endpoint that shows aggregate usage statistics per model since the server started.
2026-02-17 15:59:52 -07:00
24 changed files with 1271 additions and 1546 deletions

View File

@@ -1 +1 @@
v0.5.0
v0.4.1

View File

@@ -922,6 +922,19 @@ type UserResponse struct {
Plan string `json:"plan,omitempty"`
}
type UsageResponse struct {
// Start is the time the server started tracking usage (UTC, RFC 3339).
Start time.Time `json:"start"`
Usage []ModelUsageData `json:"usage"`
}
type ModelUsageData struct {
Model string `json:"model"`
Requests int64 `json:"requests"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
}
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
@@ -84,24 +83,3 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
// SignRequest adds a nonce query parameter and an Authorization header with
// an Ed25519 signature to req.
func SignRequest(ctx context.Context, req *http.Request) error {
nonce, err := NewNonce(rand.Reader, 16)
if err != nil {
return err
}
q := req.URL.Query()
q.Set("nonce", nonce)
req.URL.RawQuery = q.Encode()
data := []byte(fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI()))
signature, err := Sign(ctx, data)
if err != nil {
return err
}
req.Header.Set("Authorization", signature)
return nil
}

View File

@@ -1900,21 +1900,6 @@ func runInteractiveTUI(cmd *cobra.Command) {
return
}
if version.Version != "0.0.0" && version.IsOfficialInstall() && version.IsLocalHost(envconfig.Host()) {
if version.HasCachedUpdate() {
fmt.Print("A new version of Ollama is available. Run \"ollama update\" to install.\n\n")
_ = version.ClearCachedUpdate()
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if available, err := version.CheckForUpdate(ctx); err == nil && available {
_ = version.CacheAvailableUpdate()
}
}()
}
// Selector adapters for tui
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
@@ -2332,18 +2317,6 @@ func NewCLI() *cobra.Command {
}
}
updateCmd := &cobra.Command{
Use: "update",
Short: "Update Ollama to the latest version",
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
force, _ := cmd.Flags().GetBool("force")
_ = version.ClearCachedUpdate()
return version.DoUpdate(force)
},
}
updateCmd.Flags().BoolP("force", "f", false, "Force update even if installed via a package manager")
rootCmd.AddCommand(
serveCmd,
createCmd,
@@ -2361,7 +2334,6 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
runnerCmd,
updateCmd,
config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)

View File

@@ -415,12 +415,6 @@ type multiSelectorModel struct {
cancelled bool
confirmed bool
width int
// multi enables full multi-select editing mode. The zero value (false)
// shows a single-select picker where Enter adds the chosen model to
// the existing list. Tab toggles between modes.
multi bool
singleAdd string // model picked in single mode
}
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
@@ -435,23 +429,13 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
m.itemIndex[item.Name] = i
}
// Reverse order so preChecked[0] (the current default) ends up last
// in checkOrder, matching the "last checked = default" convention.
for i := len(preChecked) - 1; i >= 0; i-- {
if idx, ok := m.itemIndex[preChecked[i]]; ok {
for _, name := range preChecked {
if idx, ok := m.itemIndex[name]; ok {
m.checked[idx] = true
m.checkOrder = append(m.checkOrder, idx)
}
}
// Position cursor on the current default model
if len(preChecked) > 0 {
if idx, ok := m.itemIndex[preChecked[0]]; ok {
m.cursor = idx
m.updateScroll(m.otherStart())
}
}
return m
}
@@ -562,25 +546,14 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.cancelled = true
return m, tea.Quit
case tea.KeyTab:
m.multi = !m.multi
case tea.KeyEnter:
if !m.multi {
if len(filtered) > 0 && m.cursor < len(filtered) {
m.singleAdd = filtered[m.cursor].Name
m.confirmed = true
return m, tea.Quit
}
} else if len(m.checkOrder) > 0 {
if len(m.checkOrder) > 0 {
m.confirmed = true
return m, tea.Quit
}
case tea.KeySpace:
if m.multi {
m.toggleItem()
}
m.toggleItem()
case tea.KeyUp:
if m.cursor > 0 {
@@ -619,9 +592,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// On some terminals (e.g. Windows PowerShell), space arrives as
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
if m.multi {
m.toggleItem()
}
m.toggleItem()
} else {
m.filter += string(msg.Runes)
m.cursor = 0
@@ -633,19 +604,6 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
if idx == m.cursor {
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
} else {
s.WriteString(selectorItemStyle.Render(item.Name))
}
s.WriteString("\n")
if item.Description != "" {
s.WriteString(selectorDescLineStyle.Render(item.Description))
s.WriteString("\n")
}
}
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
origIdx := m.itemIndex[item.Name]
@@ -657,7 +615,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
}
suffix := ""
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
suffix = " " + selectorDefaultTagStyle.Render("(default)")
}
@@ -679,11 +637,6 @@ func (m multiSelectorModel) View() string {
return ""
}
renderItem := m.renderSingleItem
if m.multi {
renderItem = m.renderMultiItem
}
var s strings.Builder
s.WriteString(selectorTitleStyle.Render(m.title))
@@ -708,7 +661,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(filtered) {
break
}
renderItem(&s, filtered[idx], idx)
m.renderMultiItem(&s, filtered[idx], idx)
}
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
@@ -731,7 +684,7 @@ func (m multiSelectorModel) View() string {
s.WriteString(sectionHeaderStyle.Render("Recommended"))
s.WriteString("\n")
for _, idx := range recItems {
renderItem(&s, filtered[idx], idx)
m.renderMultiItem(&s, filtered[idx], idx)
}
}
@@ -751,7 +704,7 @@ func (m multiSelectorModel) View() string {
if idx >= len(otherItems) {
break
}
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
}
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
@@ -763,18 +716,15 @@ func (m multiSelectorModel) View() string {
s.WriteString("\n")
if !m.multi {
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
count := m.selectedCount()
if count == 0 {
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
} else {
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
}
s.WriteString("\n\n")
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
result := s.String()
if m.width > 0 {
@@ -797,28 +747,18 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
}
fm := finalModel.(multiSelectorModel)
if fm.cancelled || !fm.confirmed {
if fm.cancelled {
return nil, ErrCancelled
}
// Single-add mode: prepend the picked model, keep existing models deduped
if fm.singleAdd != "" {
result := []string{fm.singleAdd}
for _, name := range preChecked {
if name != fm.singleAdd {
result = append(result, name)
}
}
return result, nil
if !fm.confirmed {
return nil, ErrCancelled
}
// Multi-edit mode: last checked is default (first in result)
last := fm.checkOrder[len(fm.checkOrder)-1]
result := []string{fm.items[last].Name}
var result []string
for _, idx := range fm.checkOrder {
if idx != last {
result = append(result, fm.items[idx].Name)
}
result = append(result, fm.items[idx].Name)
}
return result, nil
}

View File

@@ -539,7 +539,6 @@ func TestMultiView_CursorIndicator(t *testing.T) {
func TestMultiView_CheckedItemShowsX(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
m.multi = true
content := m.View()
if !strings.Contains(content, "[x]") {
@@ -551,18 +550,11 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
}
func TestMultiView_DefaultTag(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
m.multi = true
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
content := m.View()
if !strings.Contains(content, "(default)") {
t.Error("should have (default) tag")
}
// preChecked[0] ("a") should be the default (last in checkOrder)
aIdx := strings.Index(content, "a")
defaultIdx := strings.Index(content, "(default)")
if defaultIdx < aIdx {
t.Error("(default) tag should appear after 'a' (the current default)")
t.Error("first checked item should have (default) tag")
}
}
@@ -593,7 +585,6 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeySpace
@@ -610,7 +601,6 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.multi = true
m.cursor = 1
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
@@ -628,161 +618,6 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
}
}
// --- Single-add mode ---
func TestMulti_StartsInSingleMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Error("should start in single mode (multi=false)")
}
}
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
content := m.View()
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
t.Error("single mode should not show checkboxes")
}
if !strings.Contains(content, "▸") {
t.Error("single mode should show cursor indicator")
}
}
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
m.cursor = 1
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
m = updated.(multiSelectorModel)
if m.singleAdd != "b" {
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
}
if !m.confirmed {
t.Error("should set confirmed")
}
}
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space in single mode should not toggle items")
}
}
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
m.cursor = 0
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
m = updated.(multiSelectorModel)
if len(m.checked) != 0 {
t.Error("space rune in single mode should not toggle items")
}
if m.filter != "" {
t.Error("space rune in single mode should not add to filter")
}
}
func TestMulti_TabTogglesMode(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
if m.multi {
t.Fatal("should start in single mode")
}
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if !m.multi {
t.Error("tab should switch to multi mode")
}
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
m = updated.(multiSelectorModel)
if m.multi {
t.Error("tab should switch back to single mode")
}
}
func TestMulti_SingleModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
content := m.View()
if !strings.Contains(content, "tab add multiple") {
t.Error("single mode should show 'tab add multiple' in help")
}
}
func TestMulti_MultiModeHelpText(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("a"), nil)
m.multi = true
content := m.View()
if !strings.Contains(content, "tab select single") {
t.Error("multi mode should show 'tab select single' in help")
}
}
// --- preChecked initialization order ---
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
// preChecked[0] ("a") is the current default and should end up
// last in checkOrder so it gets the (default) tag.
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
if len(m.checkOrder) != 3 {
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
}
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "a" {
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
}
}
func TestMulti_CursorOnDefaultModel(t *testing.T) {
// preChecked[0] ("b") is the default; cursor should start on it
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
if m.cursor != 1 {
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
}
}
// --- Multi-mode last-checked is default ---
func TestMulti_LastCheckedIsDefault(t *testing.T) {
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
m.multi = true
// Check "alpha" then "gamma"
m.cursor = 0
m.toggleItem()
m.cursor = 2
m.toggleItem()
// Last checked ("gamma") should be at the end of checkOrder
lastIdx := m.checkOrder[len(m.checkOrder)-1]
if m.items[lastIdx].Name != "gamma" {
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
}
// The (default) tag renders based on checkOrder[len-1]
content := m.View()
if !strings.Contains(content, "(default)") {
t.Fatal("should show (default) tag")
}
// "alpha" line should NOT have the default tag
for _, line := range strings.Split(content, "\n") {
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
t.Error("'alpha' (first checked) should not have (default) tag")
}
}
}
// Key message helpers for testing
type keyType = int

View File

@@ -429,24 +429,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.multiModalSelector.confirmed {
var selected []string
if m.multiModalSelector.singleAdd != "" {
// Single-add mode: prepend picked model, keep existing deduped
selected = []string{m.multiModalSelector.singleAdd}
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
if name != m.multiModalSelector.singleAdd {
selected = append(selected, name)
}
}
} else {
// Last checked is default (first in result)
co := m.multiModalSelector.checkOrder
last := co[len(co)-1]
selected = []string{m.multiModalSelector.items[last].Name}
for _, idx := range co {
if idx != last {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
}
for _, idx := range m.multiModalSelector.checkOrder {
selected = append(selected, m.multiModalSelector.items[idx].Name)
}
if len(selected) > 0 {
m.changeModels = selected

View File

@@ -15,6 +15,7 @@
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Usage](#usage)
- [Version](#version)
- [Experimental: Image Generation](#image-generation-experimental)
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
}
```
## Usage
```
GET /api/usage
```
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
### Examples
#### Request
```shell
curl http://localhost:11434/api/usage
```
#### Response
```json
{
"start": "2025-01-27T20:00:00Z",
"usage": [
{
"model": "llama3.2",
"requests": 5,
"prompt_tokens": 130,
"completion_tokens": 890
},
{
"model": "deepseek-r1",
"requests": 2,
"prompt_tokens": 48,
"completion_tokens": 312
}
]
}
```
#### Response fields
- `start`: when the server started tracking usage (UTC, RFC 3339)
- `usage`: list of per-model usage statistics
- `model`: model name
- `requests`: total number of completed requests
- `prompt_tokens`: total prompt tokens evaluated
- `completion_tokens`: total completion tokens generated
## Version
```

View File

@@ -91,6 +91,8 @@ type Server struct {
aliasesOnce sync.Once
aliases *store
aliasesErr error
lowVRAM bool
usage *UsageTracker
}
func init() {
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.Header("Content-Type", contentType)
fn := func(resp api.GenerateResponse) error {
if resp.Done {
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
}
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
res.Context = tokens
}
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
}
if builtinParser != nil {
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
r.GET("/api/usage", s.UsageHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
var rc *ollama.Registry
if useClient2 {
@@ -1927,6 +1937,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
func (s *Server) UsageHandler(c *gin.Context) {
c.JSON(http.StatusOK, s.usage.Stats())
}
func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{}
@@ -2097,6 +2111,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.Header("Content-Type", contentType)
fn := func(resp api.ChatResponse) error {
if resp.Done {
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
}
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
@@ -2317,6 +2335,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
}
if builtinParser != nil {

View File

@@ -30,6 +30,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -224,6 +225,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -35,6 +35,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -220,6 +221,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
if r.Method != http.MethodPost {
t.Errorf("Expected POST request, got %s", r.Method)
}
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
}
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: "test",
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
switch r.URL.Path {
case "/api/chat":
resp := api.ChatResponse{
Model: "test",
Done: true,
DoneReason: "load",
Metrics: api.Metrics{
PromptEvalCount: 10,
EvalCount: 20,
},
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
case "/api/generate":
resp := api.GenerateResponse{
Model: "test",
Done: true,
DoneReason: "stop",
Metrics: api.Metrics{
PromptEvalCount: 5,
EvalCount: 15,
},
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
default:
t.Errorf("unexpected path %s", r.URL.Path)
}
}))
defer rs.Close()
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
}
t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
s := Server{usage: NewUsageTracker()}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud",
RemoteHost: rs.URL,
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
t.Run("remote chat usage tracking", func(t *testing.T) {
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 10 {
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 20 {
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
t.Run("remote generate usage tracking", func(t *testing.T) {
// Reset the tracker for a clean test
s.usage = NewUsageTracker()
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-cloud",
Prompt: "hello",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 5 {
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 15 {
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
}
func TestGenerateChat(t *testing.T) {
@@ -177,6 +252,7 @@ func TestGenerateChat(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -894,6 +970,7 @@ func TestGenerate(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1378,6 +1455,7 @@ func TestGenerateLogprobs(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1558,6 +1636,7 @@ func TestChatLogprobs(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -1668,6 +1747,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
s := &Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2114,6 +2194,7 @@ func TestGenerateUnload(t *testing.T) {
var loadFnCalled bool
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2215,6 +2296,7 @@ func TestGenerateWithImages(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -2393,6 +2475,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
opts := api.DefaultOptions()
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
}
s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),

62
server/usage.go Normal file
View File

@@ -0,0 +1,62 @@
package server
import (
"sync"
"time"
"github.com/ollama/ollama/api"
)
type ModelUsage struct {
Requests int64
PromptTokens int64
CompletionTokens int64
}
type UsageTracker struct {
mu sync.Mutex
start time.Time
models map[string]*ModelUsage
}
func NewUsageTracker() *UsageTracker {
return &UsageTracker{
start: time.Now().UTC(),
models: make(map[string]*ModelUsage),
}
}
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
u.mu.Lock()
defer u.mu.Unlock()
m, ok := u.models[model]
if !ok {
m = &ModelUsage{}
u.models[model] = m
}
m.Requests++
m.PromptTokens += int64(promptTokens)
m.CompletionTokens += int64(completionTokens)
}
func (u *UsageTracker) Stats() api.UsageResponse {
u.mu.Lock()
defer u.mu.Unlock()
byModel := make([]api.ModelUsageData, 0, len(u.models))
for model, usage := range u.models {
byModel = append(byModel, api.ModelUsageData{
Model: model,
Requests: usage.Requests,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
})
}
return api.UsageResponse{
Start: u.start,
Usage: byModel,
}
}

136
server/usage_test.go Normal file
View File

@@ -0,0 +1,136 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestUsageTrackerRecord(t *testing.T) {
tracker := NewUsageTracker()
tracker.Record("model-a", 10, 20)
tracker.Record("model-a", 5, 15)
tracker.Record("model-b", 100, 200)
stats := tracker.Stats()
if len(stats.Usage) != 2 {
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
}
lookup := make(map[string]api.ModelUsageData)
for _, m := range stats.Usage {
lookup[m.Model] = m
}
a := lookup["model-a"]
if a.Requests != 2 {
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
}
if a.PromptTokens != 15 {
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
}
if a.CompletionTokens != 35 {
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
}
b := lookup["model-b"]
if b.Requests != 1 {
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
}
if b.PromptTokens != 100 {
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
}
if b.CompletionTokens != 200 {
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
}
}
func TestUsageTrackerConcurrent(t *testing.T) {
tracker := NewUsageTracker()
var wg sync.WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Record("model-a", 1, 2)
}()
}
wg.Wait()
stats := tracker.Stats()
if len(stats.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
}
m := stats.Usage[0]
if m.Requests != 100 {
t.Errorf("requests: expected 100, got %d", m.Requests)
}
if m.PromptTokens != 100 {
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
}
if m.CompletionTokens != 200 {
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
}
}
func TestUsageTrackerStart(t *testing.T) {
tracker := NewUsageTracker()
stats := tracker.Stats()
if stats.Start.IsZero() {
t.Error("expected non-zero start time")
}
}
func TestUsageHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
s := &Server{
usage: NewUsageTracker(),
}
s.usage.Record("llama3", 50, 100)
s.usage.Record("llama3", 25, 50)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
s.UsageHandler(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.UsageResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if len(resp.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
}
m := resp.Usage[0]
if m.Model != "llama3" {
t.Errorf("expected model llama3, got %s", m.Model)
}
if m.Requests != 2 {
t.Errorf("expected 2 requests, got %d", m.Requests)
}
if m.PromptTokens != 75 {
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 150 {
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
}
}

View File

@@ -1,190 +0,0 @@
package version
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
)
var updateCheckURLBase = "https://ollama.com"
// CheckForUpdate calls the ollama.com update API and reports whether a
// newer version is available.
func CheckForUpdate(ctx context.Context) (bool, error) {
requestURL, err := url.Parse(updateCheckURLBase + "/api/update")
if err != nil {
return false, fmt.Errorf("parse update URL: %w", err)
}
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", Version)
requestURL.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return false, fmt.Errorf("create request: %w", err)
}
_ = auth.SignRequest(ctx, req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("update check request: %w", err)
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK, nil
}
func cacheFilePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "update"), nil
}
// CacheAvailableUpdate creates the update marker file.
func CacheAvailableUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
return f.Close()
}
// HasCachedUpdate reports whether a non-stale update marker exists.
func HasCachedUpdate() bool {
path, err := cacheFilePath()
if err != nil {
return false
}
fi, err := os.Stat(path)
if err != nil {
return false
}
return time.Since(fi.ModTime()) <= 24*time.Hour
}
// ClearCachedUpdate removes the update marker file.
func ClearCachedUpdate() error {
path, err := cacheFilePath()
if err != nil {
return err
}
err = os.Remove(path)
if os.IsNotExist(err) {
return nil
}
return err
}
func IsOfficialInstall() bool {
exe, err := os.Executable()
if err != nil {
return false
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return false
}
switch runtime.GOOS {
case "windows":
localAppData := os.Getenv("LOCALAPPDATA")
if localAppData == "" {
return false
}
return strings.HasPrefix(strings.ToLower(exe), strings.ToLower(filepath.Join(localAppData, "Programs", "Ollama")+string(filepath.Separator)))
case "darwin":
return strings.HasPrefix(exe, "/Applications/Ollama.app/")
default:
dir := filepath.Dir(exe)
return dir == "/usr/local/bin" || dir == "/usr/bin" || dir == "/bin"
}
}
// DoUpdate downloads and runs the platform-appropriate install script.
func DoUpdate(force bool) error {
if !force && !IsOfficialInstall() {
return fmt.Errorf("ollama appears to be installed through a package manager. Please update it using your package manager")
}
var scriptURL, tmpPattern, shell string
switch runtime.GOOS {
case "windows":
scriptURL = "https://ollama.com/install.ps1"
tmpPattern = "ollama-install-*.ps1"
shell = "powershell"
default:
scriptURL = "https://ollama.com/install.sh"
tmpPattern = "ollama-install-*.sh"
shell = "sh"
}
resp, err := http.Get(scriptURL)
if err != nil {
return fmt.Errorf("download install script: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download install script: status %d", resp.StatusCode)
}
tmpFile, err := os.CreateTemp("", tmpPattern)
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
if _, err := io.Copy(tmpFile, resp.Body); err != nil {
tmpFile.Close()
return fmt.Errorf("write install script: %w", err)
}
tmpFile.Close()
cmd := exec.Command(shell, tmpFile.Name())
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
// IsLocalHost reports whether the configured Ollama host points to the
// local machine.
func IsLocalHost(host *url.URL) bool {
hostname := host.Hostname()
switch hostname {
case "", "127.0.0.1", "localhost", "::1", "0.0.0.0":
return true
}
if ip := net.ParseIP(hostname); ip != nil {
return ip.IsLoopback()
}
return false
}

View File

@@ -1,146 +0,0 @@
package version
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"testing"
"time"
)
func setHome(t *testing.T, dir string) {
t.Helper()
if runtime.GOOS == "windows" {
t.Setenv("USERPROFILE", dir)
} else {
t.Setenv("HOME", dir)
}
}
func TestCheckForUpdate(t *testing.T) {
t.Run("update available", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("os") == "" || r.URL.Query().Get("arch") == "" || r.URL.Query().Get("version") == "" {
t.Error("missing expected query parameters")
}
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !available {
t.Fatal("expected update to be available")
}
})
t.Run("up to date", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer ts.Close()
old := updateCheckURLBase
updateCheckURLBase = ts.URL
defer func() { updateCheckURLBase = old }()
available, err := CheckForUpdate(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if available {
t.Fatal("expected no update available")
}
})
t.Run("network error", func(t *testing.T) {
old := updateCheckURLBase
updateCheckURLBase = "http://localhost:1"
defer func() { updateCheckURLBase = old }()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := CheckForUpdate(ctx)
if err == nil {
t.Fatal("expected error for unreachable server")
}
})
}
func TestCacheRoundTrip(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
if !HasCachedUpdate() {
t.Fatal("expected cached update to be present")
}
if err := ClearCachedUpdate(); err != nil {
t.Fatalf("cache clear: %v", err)
}
if HasCachedUpdate() {
t.Fatal("expected no cached update after clear")
}
}
func TestHasCachedUpdateStale(t *testing.T) {
tmp := t.TempDir()
setHome(t, tmp)
os.MkdirAll(filepath.Join(tmp, ".ollama"), 0o755)
if err := CacheAvailableUpdate(); err != nil {
t.Fatalf("cache write: %v", err)
}
// Backdate the file to make it stale
path := filepath.Join(tmp, ".ollama", "update")
staleTime := time.Now().Add(-25 * time.Hour)
os.Chtimes(path, staleTime, staleTime)
if HasCachedUpdate() {
t.Fatal("expected no cached update for stale file")
}
}
func TestIsLocalHost(t *testing.T) {
tests := []struct {
host string
local bool
}{
{"http://127.0.0.1:11434", true},
{"http://localhost:11434", true},
{"http://[::1]:11434", true},
{"http://0.0.0.0:11434", true},
{"http://remote.example.com:11434", false},
{"http://192.168.1.100:11434", false},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
u, err := url.Parse(tt.host)
if err != nil {
t.Fatalf("parse URL: %v", err)
}
if got := IsLocalHost(u); got != tt.local {
t.Errorf("IsLocalHost(%s) = %v, want %v", tt.host, got, tt.local)
}
})
}
}

View File

@@ -16,10 +16,10 @@ import (
)
type Function struct {
Name string
ReturnType string
Params string
ParamNames []string
Name string
ReturnType string
Params string
ParamNames []string
NeedsARM64Guard bool
}
@@ -29,11 +29,6 @@ func findHeaders(directory string) ([]string, error) {
if err != nil {
return err
}
// Private headers contain C++ implementation helpers and are not part of
// the C API surface; parsing them can produce invalid wrapper signatures.
if d.IsDir() && d.Name() == "private" {
return fs.SkipDir
}
if !d.IsDir() && strings.HasSuffix(path, ".h") {
headers = append(headers, path)
}
@@ -199,10 +194,10 @@ func parseFunctions(content string) []Function {
needsGuard := needsARM64Guard(funcName, returnType, params)
functions = append(functions, Function{
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
Name: funcName,
ReturnType: returnType,
Params: params,
ParamNames: paramNames,
NeedsARM64Guard: needsGuard,
})
}

View File

@@ -20,8 +20,6 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
@@ -51,7 +49,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
#endif
@@ -69,7 +67,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
#endif
@@ -125,7 +123,6 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
@@ -136,16 +133,6 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
@@ -276,6 +263,7 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
@@ -670,16 +658,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
if (mlx_array_new_data_managed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
return -1;
}
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
if (mlx_array_new_data_managed_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
@@ -1163,11 +1141,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
if (mlx_cuda_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
@@ -1218,56 +1191,6 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
if (mlx_device_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
return -1;
}
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
if (mlx_device_count_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
return -1;
}
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
if (mlx_device_info_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
return -1;
}
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
if (mlx_device_info_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
return -1;
}
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
if (mlx_device_info_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
return -1;
}
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
if (mlx_device_info_has_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
return -1;
}
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
if (mlx_device_info_is_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
return -1;
}
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
if (mlx_device_info_get_string_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
return -1;
}
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
if (mlx_device_info_get_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
return -1;
}
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
if (mlx_device_info_get_keys_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
@@ -1918,6 +1841,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
@@ -3600,14 +3528,6 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
return mlx_array_new_data_ptr(data, shape, dim, dtype);
}
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
}
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
}
int mlx_array_set(mlx_array* arr, const mlx_array src) {
return mlx_array_set_ptr(arr, src);
}
@@ -3724,7 +3644,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
@@ -3784,7 +3704,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
@@ -3996,10 +3916,6 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
return mlx_set_compile_mode_ptr(mode);
}
int mlx_cuda_is_available(bool* res) {
return mlx_cuda_is_available_ptr(res);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
@@ -4040,46 +3956,6 @@ int mlx_set_default_device(mlx_device dev) {
return mlx_set_default_device_ptr(dev);
}
int mlx_device_is_available(bool* avail, mlx_device dev) {
return mlx_device_is_available_ptr(avail, dev);
}
int mlx_device_count(int* count, mlx_device_type type) {
return mlx_device_count_ptr(count, type);
}
mlx_device_info mlx_device_info_new(void) {
return mlx_device_info_new_ptr();
}
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
return mlx_device_info_get_ptr(info, dev);
}
int mlx_device_info_free(mlx_device_info info) {
return mlx_device_info_free_ptr(info);
}
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
return mlx_device_info_has_key_ptr(exists, info, key);
}
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
return mlx_device_info_is_string_ptr(is_string, info, key);
}
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
return mlx_device_info_get_string_ptr(value, info, key);
}
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
return mlx_device_info_get_size_ptr(value, info, key);
}
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
return mlx_device_info_get_keys_ptr(keys, info);
}
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
return mlx_distributed_all_gather_ptr(res, x, group, S);
}
@@ -4600,6 +4476,10 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
return mlx_set_wired_limit_ptr(res, limit);
}
mlx_metal_device_info_t mlx_metal_device_info(void) {
return mlx_metal_device_info_ptr();
}
int mlx_metal_is_available(bool* res) {
return mlx_metal_is_available_ptr(res);
}

View File

@@ -26,8 +26,6 @@
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_new_data_managed
#undef mlx_array_new_data_managed_payload
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
@@ -123,7 +121,6 @@
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_cuda_is_available
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
@@ -134,16 +131,6 @@
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_device_is_available
#undef mlx_device_count
#undef mlx_device_info_new
#undef mlx_device_info_get
#undef mlx_device_info_free
#undef mlx_device_info_has_key
#undef mlx_device_info_is_string
#undef mlx_device_info_get_string
#undef mlx_device_info_get_size
#undef mlx_device_info_get_keys
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
@@ -274,6 +261,7 @@
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
@@ -614,8 +602,6 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
extern mlx_array (*mlx_array_new_double_ptr)(double val);
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
@@ -645,7 +631,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
#endif
@@ -663,7 +649,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
#endif
@@ -719,7 +705,6 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
extern int (*mlx_disable_compile_ptr)(void);
extern int (*mlx_enable_compile_ptr)(void);
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
extern int (*mlx_cuda_is_available_ptr)(bool* res);
extern mlx_device (*mlx_device_new_ptr)(void);
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
extern int (*mlx_device_free_ptr)(mlx_device dev);
@@ -730,16 +715,6 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -870,6 +845,7 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
extern int (*mlx_metal_is_available_ptr)(bool* res);
extern int (*mlx_metal_start_capture_ptr)(const char* path);
extern int (*mlx_metal_stop_capture_ptr)(void);
@@ -1226,10 +1202,6 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
int mlx_array_set(mlx_array* arr, const mlx_array src);
int mlx_array_set_bool(mlx_array* arr, bool val);
@@ -1288,7 +1260,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
int mlx_array_item_float64(double* res, const mlx_array arr);
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
@@ -1320,7 +1292,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
const double* mlx_array_data_float64(const mlx_array arr);
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* mlx_array_data_float16(const mlx_array arr);
@@ -1428,8 +1400,6 @@ int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
int mlx_cuda_is_available(bool* res);
mlx_device mlx_device_new(void);
mlx_device mlx_device_new_type(mlx_device_type type, int index);
@@ -1450,26 +1420,6 @@ int mlx_get_default_device(mlx_device* dev);
int mlx_set_default_device(mlx_device dev);
int mlx_device_is_available(bool* avail, mlx_device dev);
int mlx_device_count(int* count, mlx_device_type type);
mlx_device_info mlx_device_info_new(void);
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
int mlx_device_info_free(mlx_device_info info);
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
@@ -1730,6 +1680,8 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
mlx_metal_device_info_t mlx_metal_device_info(void);
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);

View File

@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
mlx-c

View File

@@ -22,19 +22,6 @@ mlx_array (*mlx_array_new_data_)(
const int* shape,
int dim,
mlx_dtype dtype) = NULL;
mlx_array (*mlx_array_new_data_managed_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*)) = NULL;
mlx_array (*mlx_array_new_data_managed_payload_)(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*)) = NULL;
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
@@ -69,7 +56,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
@@ -83,7 +70,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
@@ -107,11 +94,10 @@ int (*mlx_closure_apply_)(
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -150,12 +136,11 @@ int (*mlx_closure_value_and_grad_apply_)(
const mlx_vector_array input) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -176,13 +161,12 @@ int (*mlx_closure_custom_apply_)(
const mlx_vector_array input_2) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -205,13 +189,12 @@ int (*mlx_closure_custom_jvp_apply_)(
size_t input_2_num) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
@@ -245,7 +228,6 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_)(void) = NULL;
int (*mlx_enable_compile_)(void) = NULL;
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
int (*mlx_cuda_is_available_)(bool* res) = NULL;
mlx_device (*mlx_device_new_)(void) = NULL;
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_)(mlx_device dev) = NULL;
@@ -256,28 +238,11 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
int (*mlx_device_info_has_key_)(
bool* exists,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_is_string_)(
bool* is_string,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_string_)(
const char** value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_size_)(
size_t* value,
mlx_device_info info,
const char* key) = NULL;
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
@@ -323,11 +288,6 @@ int (*mlx_distributed_sum_scatter_)(
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -490,16 +450,6 @@ int (*mlx_fast_rope_)(
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
@@ -610,6 +560,14 @@ int (*mlx_fft_rfftn_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
@@ -635,14 +593,6 @@ int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata) = NULL;
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
@@ -783,6 +733,7 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
int (*mlx_metal_is_available_)(bool* res) = NULL;
int (*mlx_metal_start_capture_)(const char* path) = NULL;
int (*mlx_metal_stop_capture_)(void) = NULL;
@@ -1211,14 +1162,6 @@ int (*mlx_gather_)(
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s) = NULL;
int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1540,15 +1483,6 @@ int (*mlx_put_along_axis_)(
const mlx_array values,
int axis,
const mlx_stream s) = NULL;
int (*mlx_qqmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array w_scales /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
@@ -1632,13 +1566,6 @@ int (*mlx_scatter_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
@@ -1647,13 +1574,6 @@ int (*mlx_scatter_add_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
@@ -1669,13 +1589,6 @@ int (*mlx_scatter_max_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
@@ -1684,13 +1597,6 @@ int (*mlx_scatter_min_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
@@ -1699,13 +1605,6 @@ int (*mlx_scatter_prod_)(
const int* axes,
size_t axes_num,
const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array updates,
int axis,
const mlx_stream s) = NULL;
int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2129,6 +2028,22 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
const char * (*mlx_string_data_)(mlx_string str) = NULL;
int (*mlx_string_free_)(mlx_string str) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_)(
@@ -2159,22 +2074,6 @@ int (*mlx_vjp_)(
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num) = NULL;
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
@@ -2267,8 +2166,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_array_new_double);
CHECK_LOAD(handle, mlx_array_new_complex);
CHECK_LOAD(handle, mlx_array_new_data);
CHECK_LOAD(handle, mlx_array_new_data_managed);
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
CHECK_LOAD(handle, mlx_array_set);
CHECK_LOAD(handle, mlx_array_set_bool);
CHECK_LOAD(handle, mlx_array_set_int);
@@ -2364,7 +2261,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_disable_compile);
CHECK_LOAD(handle, mlx_enable_compile);
CHECK_LOAD(handle, mlx_set_compile_mode);
CHECK_LOAD(handle, mlx_cuda_is_available);
CHECK_LOAD(handle, mlx_device_new);
CHECK_LOAD(handle, mlx_device_new_type);
CHECK_LOAD(handle, mlx_device_free);
@@ -2375,16 +2271,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_device_get_type);
CHECK_LOAD(handle, mlx_get_default_device);
CHECK_LOAD(handle, mlx_set_default_device);
CHECK_LOAD(handle, mlx_device_is_available);
CHECK_LOAD(handle, mlx_device_count);
CHECK_LOAD(handle, mlx_device_info_new);
CHECK_LOAD(handle, mlx_device_info_get);
CHECK_LOAD(handle, mlx_device_info_free);
CHECK_LOAD(handle, mlx_device_info_has_key);
CHECK_LOAD(handle, mlx_device_info_is_string);
CHECK_LOAD(handle, mlx_device_info_get_string);
CHECK_LOAD(handle, mlx_device_info_get_size);
CHECK_LOAD(handle, mlx_device_info_get_keys);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_distributed_all_gather);
CHECK_LOAD(handle, mlx_distributed_all_max);
CHECK_LOAD(handle, mlx_distributed_all_min);
@@ -2393,11 +2284,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_distributed_recv_like);
CHECK_LOAD(handle, mlx_distributed_send);
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
CHECK_LOAD(handle, mlx_distributed_group_rank);
CHECK_LOAD(handle, mlx_distributed_group_size);
CHECK_LOAD(handle, mlx_distributed_group_split);
CHECK_LOAD(handle, mlx_distributed_is_available);
CHECK_LOAD(handle, mlx_distributed_init);
CHECK_LOAD(handle, mlx_set_error_handler);
CHECK_LOAD(handle, _mlx_error);
CHECK_LOAD(handle, mlx_export_function);
@@ -2439,7 +2325,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
CHECK_LOAD(handle, mlx_fast_rms_norm);
CHECK_LOAD(handle, mlx_fast_rope);
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
CHECK_LOAD(handle, mlx_fft_fft);
CHECK_LOAD(handle, mlx_fft_fft2);
@@ -2455,14 +2340,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_fft_rfft);
CHECK_LOAD(handle, mlx_fft_rfft2);
CHECK_LOAD(handle, mlx_fft_rfftn);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_io_reader_new);
CHECK_LOAD(handle, mlx_io_reader_descriptor);
CHECK_LOAD(handle, mlx_io_reader_tostring);
@@ -2471,6 +2348,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_io_writer_descriptor);
CHECK_LOAD(handle, mlx_io_writer_tostring);
CHECK_LOAD(handle, mlx_io_writer_free);
CHECK_LOAD(handle, mlx_load_reader);
CHECK_LOAD(handle, mlx_load);
CHECK_LOAD(handle, mlx_load_safetensors_reader);
CHECK_LOAD(handle, mlx_load_safetensors);
CHECK_LOAD(handle, mlx_save_writer);
CHECK_LOAD(handle, mlx_save);
CHECK_LOAD(handle, mlx_save_safetensors_writer);
CHECK_LOAD(handle, mlx_save_safetensors);
CHECK_LOAD(handle, mlx_linalg_cholesky);
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
CHECK_LOAD(handle, mlx_linalg_cross);
@@ -2515,6 +2400,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_set_cache_limit);
CHECK_LOAD(handle, mlx_set_memory_limit);
CHECK_LOAD(handle, mlx_set_wired_limit);
CHECK_LOAD(handle, mlx_metal_device_info);
CHECK_LOAD(handle, mlx_metal_is_available);
CHECK_LOAD(handle, mlx_metal_start_capture);
CHECK_LOAD(handle, mlx_metal_stop_capture);
@@ -2600,7 +2486,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_full);
CHECK_LOAD(handle, mlx_full_like);
CHECK_LOAD(handle, mlx_gather);
CHECK_LOAD(handle, mlx_gather_single);
CHECK_LOAD(handle, mlx_gather_mm);
CHECK_LOAD(handle, mlx_gather_qmm);
CHECK_LOAD(handle, mlx_greater);
@@ -2665,7 +2550,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_prod_axis);
CHECK_LOAD(handle, mlx_prod);
CHECK_LOAD(handle, mlx_put_along_axis);
CHECK_LOAD(handle, mlx_qqmm);
CHECK_LOAD(handle, mlx_quantize);
CHECK_LOAD(handle, mlx_quantized_matmul);
CHECK_LOAD(handle, mlx_radians);
@@ -2682,16 +2566,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_round);
CHECK_LOAD(handle, mlx_rsqrt);
CHECK_LOAD(handle, mlx_scatter);
CHECK_LOAD(handle, mlx_scatter_single);
CHECK_LOAD(handle, mlx_scatter_add);
CHECK_LOAD(handle, mlx_scatter_add_single);
CHECK_LOAD(handle, mlx_scatter_add_axis);
CHECK_LOAD(handle, mlx_scatter_max);
CHECK_LOAD(handle, mlx_scatter_max_single);
CHECK_LOAD(handle, mlx_scatter_min);
CHECK_LOAD(handle, mlx_scatter_min_single);
CHECK_LOAD(handle, mlx_scatter_prod);
CHECK_LOAD(handle, mlx_scatter_prod_single);
CHECK_LOAD(handle, mlx_segmented_mm);
CHECK_LOAD(handle, mlx_sigmoid);
CHECK_LOAD(handle, mlx_sign);
@@ -2786,6 +2665,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_string_set);
CHECK_LOAD(handle, mlx_string_data);
CHECK_LOAD(handle, mlx_string_free);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_async_eval);
CHECK_LOAD(handle, mlx_checkpoint);
CHECK_LOAD(handle, mlx_custom_function);
@@ -2794,8 +2675,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_jvp);
CHECK_LOAD(handle, mlx_value_and_grad);
CHECK_LOAD(handle, mlx_vjp);
CHECK_LOAD(handle, mlx_detail_vmap_replace);
CHECK_LOAD(handle, mlx_detail_vmap_trace);
CHECK_LOAD(handle, mlx_vector_array_new);
CHECK_LOAD(handle, mlx_vector_array_set);
CHECK_LOAD(handle, mlx_vector_array_free);

View File

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,6 @@
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}