mirror of
https://github.com/ollama/ollama.git
synced 2026-02-18 15:25:27 -05:00
Compare commits
3 Commits
v0.16.3-rc
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ac1300805 | ||
|
|
43d9907dd6 | ||
|
|
91dc088e8b |
@@ -1 +1 @@
|
||||
v0.5.0
|
||||
v0.4.1
|
||||
|
||||
13
api/types.go
13
api/types.go
@@ -922,6 +922,19 @@ type UserResponse struct {
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
// UsageResponse is the JSON payload describing the usage recorded since
// Start, with one ModelUsageData entry per model.
type UsageResponse struct {
|
||||
// Start is the time the server started tracking usage (UTC, RFC 3339).
|
||||
Start time.Time `json:"start"`
|
||||
// Usage lists the recorded totals, one entry per model.
Usage []ModelUsageData `json:"usage"`
|
||||
}
|
||||
|
||||
// ModelUsageData holds the accumulated request and token totals reported
// for a single model.
type ModelUsageData struct {
|
||||
// Model is the name of the model the totals belong to.
Model string `json:"model"`
|
||||
// Requests is the number of requests recorded for the model.
Requests int64 `json:"requests"`
|
||||
// PromptTokens is the sum of prompt tokens recorded for the model.
PromptTokens int64 `json:"prompt_tokens"`
|
||||
// CompletionTokens is the sum of completion tokens recorded for the model.
CompletionTokens int64 `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -415,12 +415,6 @@ type multiSelectorModel struct {
|
||||
cancelled bool
|
||||
confirmed bool
|
||||
width int
|
||||
|
||||
// multi enables full multi-select editing mode. The zero value (false)
|
||||
// shows a single-select picker where Enter adds the chosen model to
|
||||
// the existing list. Tab toggles between modes.
|
||||
multi bool
|
||||
singleAdd string // model picked in single mode
|
||||
}
|
||||
|
||||
func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
|
||||
@@ -435,23 +429,13 @@ func newMultiSelectorModel(title string, items []SelectItem, preChecked []string
|
||||
m.itemIndex[item.Name] = i
|
||||
}
|
||||
|
||||
// Reverse order so preChecked[0] (the current default) ends up last
|
||||
// in checkOrder, matching the "last checked = default" convention.
|
||||
for i := len(preChecked) - 1; i >= 0; i-- {
|
||||
if idx, ok := m.itemIndex[preChecked[i]]; ok {
|
||||
for _, name := range preChecked {
|
||||
if idx, ok := m.itemIndex[name]; ok {
|
||||
m.checked[idx] = true
|
||||
m.checkOrder = append(m.checkOrder, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Position cursor on the current default model
|
||||
if len(preChecked) > 0 {
|
||||
if idx, ok := m.itemIndex[preChecked[0]]; ok {
|
||||
m.cursor = idx
|
||||
m.updateScroll(m.otherStart())
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -562,25 +546,14 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.cancelled = true
|
||||
return m, tea.Quit
|
||||
|
||||
case tea.KeyTab:
|
||||
m.multi = !m.multi
|
||||
|
||||
case tea.KeyEnter:
|
||||
if !m.multi {
|
||||
if len(filtered) > 0 && m.cursor < len(filtered) {
|
||||
m.singleAdd = filtered[m.cursor].Name
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
} else if len(m.checkOrder) > 0 {
|
||||
if len(m.checkOrder) > 0 {
|
||||
m.confirmed = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
case tea.KeySpace:
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
m.toggleItem()
|
||||
|
||||
case tea.KeyUp:
|
||||
if m.cursor > 0 {
|
||||
@@ -619,9 +592,7 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// On some terminals (e.g. Windows PowerShell), space arrives as
|
||||
// KeyRunes instead of KeySpace. Intercept it so toggle still works.
|
||||
if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
|
||||
if m.multi {
|
||||
m.toggleItem()
|
||||
}
|
||||
m.toggleItem()
|
||||
} else {
|
||||
m.filter += string(msg.Runes)
|
||||
m.cursor = 0
|
||||
@@ -633,19 +604,6 @@ func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
if idx == m.cursor {
|
||||
s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
|
||||
} else {
|
||||
s.WriteString(selectorItemStyle.Render(item.Name))
|
||||
}
|
||||
s.WriteString("\n")
|
||||
if item.Description != "" {
|
||||
s.WriteString(selectorDescLineStyle.Render(item.Description))
|
||||
s.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
|
||||
origIdx := m.itemIndex[item.Name]
|
||||
|
||||
@@ -657,7 +615,7 @@ func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem,
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
|
||||
if len(m.checkOrder) > 0 && m.checkOrder[0] == origIdx {
|
||||
suffix = " " + selectorDefaultTagStyle.Render("(default)")
|
||||
}
|
||||
|
||||
@@ -679,11 +637,6 @@ func (m multiSelectorModel) View() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
renderItem := m.renderSingleItem
|
||||
if m.multi {
|
||||
renderItem = m.renderMultiItem
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
|
||||
s.WriteString(selectorTitleStyle.Render(m.title))
|
||||
@@ -708,7 +661,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(filtered) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
}
|
||||
|
||||
if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -731,7 +684,7 @@ func (m multiSelectorModel) View() string {
|
||||
s.WriteString(sectionHeaderStyle.Render("Recommended"))
|
||||
s.WriteString("\n")
|
||||
for _, idx := range recItems {
|
||||
renderItem(&s, filtered[idx], idx)
|
||||
m.renderMultiItem(&s, filtered[idx], idx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,7 +704,7 @@ func (m multiSelectorModel) View() string {
|
||||
if idx >= len(otherItems) {
|
||||
break
|
||||
}
|
||||
renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
m.renderMultiItem(&s, filtered[otherItems[idx]], otherItems[idx])
|
||||
}
|
||||
|
||||
if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
|
||||
@@ -763,18 +716,15 @@ func (m multiSelectorModel) View() string {
|
||||
|
||||
s.WriteString("\n")
|
||||
|
||||
if !m.multi {
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
count := m.selectedCount()
|
||||
if count == 0 {
|
||||
s.WriteString(selectorDescStyle.Render(" Select at least one model."))
|
||||
} else {
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
|
||||
s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
|
||||
}
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • enter confirm • esc cancel"))
|
||||
|
||||
result := s.String()
|
||||
if m.width > 0 {
|
||||
@@ -797,28 +747,18 @@ func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]st
|
||||
}
|
||||
|
||||
fm := finalModel.(multiSelectorModel)
|
||||
if fm.cancelled || !fm.confirmed {
|
||||
if fm.cancelled {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
// Single-add mode: prepend the picked model, keep existing models deduped
|
||||
if fm.singleAdd != "" {
|
||||
result := []string{fm.singleAdd}
|
||||
for _, name := range preChecked {
|
||||
if name != fm.singleAdd {
|
||||
result = append(result, name)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
if !fm.confirmed {
|
||||
return nil, ErrCancelled
|
||||
}
|
||||
|
||||
// Multi-edit mode: last checked is default (first in result)
|
||||
last := fm.checkOrder[len(fm.checkOrder)-1]
|
||||
result := []string{fm.items[last].Name}
|
||||
var result []string
|
||||
for _, idx := range fm.checkOrder {
|
||||
if idx != last {
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
result = append(result, fm.items[idx].Name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -539,7 +539,6 @@ func TestMultiView_CursorIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "[x]") {
|
||||
@@ -551,18 +550,11 @@ func TestMultiView_CheckedItemShowsX(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMultiView_DefaultTag(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
|
||||
m.multi = true
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
|
||||
content := m.View()
|
||||
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Error("should have (default) tag")
|
||||
}
|
||||
// preChecked[0] ("a") should be the default (last in checkOrder)
|
||||
aIdx := strings.Index(content, "a")
|
||||
defaultIdx := strings.Index(content, "(default)")
|
||||
if defaultIdx < aIdx {
|
||||
t.Error("(default) tag should appear after 'a' (the current default)")
|
||||
t.Error("first checked item should have (default) tag")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -593,7 +585,6 @@ func TestMultiView_OverflowIndicator(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeySpace
|
||||
@@ -610,7 +601,6 @@ func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
|
||||
|
||||
func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.multi = true
|
||||
m.cursor = 1
|
||||
|
||||
// Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
|
||||
@@ -628,161 +618,6 @@ func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Single-add mode ---
|
||||
|
||||
func TestMulti_StartsInSingleMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
if m.multi {
|
||||
t.Error("should start in single mode (multi=false)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
content := m.View()
|
||||
if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
|
||||
t.Error("single mode should not show checkboxes")
|
||||
}
|
||||
if !strings.Contains(content, "▸") {
|
||||
t.Error("single mode should show cursor indicator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
|
||||
m.cursor = 1
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if m.singleAdd != "b" {
|
||||
t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
|
||||
}
|
||||
if !m.confirmed {
|
||||
t.Error("should set confirmed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space in single mode should not toggle items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
m.cursor = 0
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
|
||||
m = updated.(multiSelectorModel)
|
||||
|
||||
if len(m.checked) != 0 {
|
||||
t.Error("space rune in single mode should not toggle items")
|
||||
}
|
||||
if m.filter != "" {
|
||||
t.Error("space rune in single mode should not add to filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_TabTogglesMode(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
|
||||
|
||||
if m.multi {
|
||||
t.Fatal("should start in single mode")
|
||||
}
|
||||
|
||||
updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if !m.multi {
|
||||
t.Error("tab should switch to multi mode")
|
||||
}
|
||||
|
||||
updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
|
||||
m = updated.(multiSelectorModel)
|
||||
if m.multi {
|
||||
t.Error("tab should switch back to single mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_SingleModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab add multiple") {
|
||||
t.Error("single mode should show 'tab add multiple' in help")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_MultiModeHelpText(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("a"), nil)
|
||||
m.multi = true
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "tab select single") {
|
||||
t.Error("multi mode should show 'tab select single' in help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- preChecked initialization order ---
|
||||
|
||||
func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
|
||||
// preChecked[0] ("a") is the current default and should end up
|
||||
// last in checkOrder so it gets the (default) tag.
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
|
||||
|
||||
if len(m.checkOrder) != 3 {
|
||||
t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
|
||||
}
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "a" {
|
||||
t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMulti_CursorOnDefaultModel(t *testing.T) {
|
||||
// preChecked[0] ("b") is the default; cursor should start on it
|
||||
m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
|
||||
|
||||
if m.cursor != 1 {
|
||||
t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-mode last-checked is default ---
|
||||
|
||||
func TestMulti_LastCheckedIsDefault(t *testing.T) {
|
||||
m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
|
||||
m.multi = true
|
||||
|
||||
// Check "alpha" then "gamma"
|
||||
m.cursor = 0
|
||||
m.toggleItem()
|
||||
m.cursor = 2
|
||||
m.toggleItem()
|
||||
|
||||
// Last checked ("gamma") should be at the end of checkOrder
|
||||
lastIdx := m.checkOrder[len(m.checkOrder)-1]
|
||||
if m.items[lastIdx].Name != "gamma" {
|
||||
t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
|
||||
}
|
||||
|
||||
// The (default) tag renders based on checkOrder[len-1]
|
||||
content := m.View()
|
||||
if !strings.Contains(content, "(default)") {
|
||||
t.Fatal("should show (default) tag")
|
||||
}
|
||||
// "alpha" line should NOT have the default tag
|
||||
for _, line := range strings.Split(content, "\n") {
|
||||
if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
|
||||
t.Error("'alpha' (first checked) should not have (default) tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Key message helpers for testing
|
||||
|
||||
type keyType = int
|
||||
|
||||
@@ -429,24 +429,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
if m.multiModalSelector.confirmed {
|
||||
var selected []string
|
||||
if m.multiModalSelector.singleAdd != "" {
|
||||
// Single-add mode: prepend picked model, keep existing deduped
|
||||
selected = []string{m.multiModalSelector.singleAdd}
|
||||
for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
|
||||
if name != m.multiModalSelector.singleAdd {
|
||||
selected = append(selected, name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Last checked is default (first in result)
|
||||
co := m.multiModalSelector.checkOrder
|
||||
last := co[len(co)-1]
|
||||
selected = []string{m.multiModalSelector.items[last].Name}
|
||||
for _, idx := range co {
|
||||
if idx != last {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
}
|
||||
for _, idx := range m.multiModalSelector.checkOrder {
|
||||
selected = append(selected, m.multiModalSelector.items[idx].Name)
|
||||
}
|
||||
if len(selected) > 0 {
|
||||
m.changeModels = selected
|
||||
|
||||
48
docs/api.md
48
docs/api.md
@@ -15,6 +15,7 @@
|
||||
- [Push a Model](#push-a-model)
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Usage](#usage)
|
||||
- [Version](#version)
|
||||
- [Experimental: Image Generation](#image-generation-experimental)
|
||||
|
||||
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
GET /api/usage
|
||||
```
|
||||
|
||||
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/usage
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"start": "2025-01-27T20:00:00Z",
|
||||
"usage": [
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"requests": 5,
|
||||
"prompt_tokens": 130,
|
||||
"completion_tokens": 890
|
||||
},
|
||||
{
|
||||
"model": "deepseek-r1",
|
||||
"requests": 2,
|
||||
"prompt_tokens": 48,
|
||||
"completion_tokens": 312
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Response fields
|
||||
|
||||
- `start`: when the server started tracking usage (UTC, RFC 3339)
- `usage`: list of per-model usage statistics; each entry contains:
  - `model`: model name
  - `requests`: total number of completed requests
  - `prompt_tokens`: total prompt tokens evaluated
  - `completion_tokens`: total completion tokens generated
|
||||
|
||||
## Version
|
||||
|
||||
```
|
||||
|
||||
@@ -91,6 +91,8 @@ type Server struct {
|
||||
aliasesOnce sync.Once
|
||||
aliases *store
|
||||
aliasesErr error
|
||||
lowVRAM bool
|
||||
usage *UsageTracker
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
res.Context = tokens
|
||||
}
|
||||
|
||||
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
|
||||
r.GET("/api/usage", s.UsageHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
@@ -1927,6 +1937,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (s *Server) UsageHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, s.usage.Stats())
|
||||
}
|
||||
|
||||
func (s *Server) PsHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
|
||||
@@ -2097,6 +2111,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
fn := func(resp api.ChatResponse) error {
|
||||
if resp.Done {
|
||||
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||
}
|
||||
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
@@ -2317,6 +2335,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
|
||||
@@ -30,6 +30,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -224,6 +225,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -35,6 +35,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -220,6 +221,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/api/chat":
|
||||
resp := api.ChatResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 10,
|
||||
EvalCount: 20,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case "/api/generate":
|
||||
resp := api.GenerateResponse{
|
||||
Model: "test",
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: 5,
|
||||
EvalCount: 15,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected path %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer rs.Close()
|
||||
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||
s := Server{}
|
||||
s := Server{usage: NewUsageTracker()}
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud",
|
||||
RemoteHost: rs.URL,
|
||||
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
|
||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote chat usage tracking", func(t *testing.T) {
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 10 {
|
||||
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 20 {
|
||||
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remote generate usage tracking", func(t *testing.T) {
|
||||
// Reset the tracker for a clean test
|
||||
s.usage = NewUsageTracker()
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-cloud",
|
||||
Prompt: "hello",
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
stats := s.usage.Stats()
|
||||
found := false
|
||||
for _, m := range stats.Usage {
|
||||
if m.Model == "test-cloud" {
|
||||
found = true
|
||||
if m.Requests != 1 {
|
||||
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 5 {
|
||||
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 15 {
|
||||
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected usage entry for test-cloud")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateChat(t *testing.T) {
|
||||
@@ -177,6 +252,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -894,6 +970,7 @@ func TestGenerate(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1378,6 +1455,7 @@ func TestGenerateLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1558,6 +1636,7 @@ func TestChatLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -1668,6 +1747,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2114,6 +2194,7 @@ func TestGenerateUnload(t *testing.T) {
|
||||
var loadFnCalled bool
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2215,6 +2296,7 @@ func TestGenerateWithImages(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -2393,6 +2475,7 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
}
|
||||
|
||||
s := Server{
|
||||
usage: NewUsageTracker(),
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
|
||||
62
server/usage.go
Normal file
62
server/usage.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
	"sort"
	"sync"
	"time"

	"github.com/ollama/ollama/api"
)
|
||||
|
||||
// ModelUsage holds the running totals recorded for a single model.
type ModelUsage struct {
	Requests         int64 // number of Record calls made for the model
	PromptTokens     int64 // sum of the promptTokens values passed to Record
	CompletionTokens int64 // sum of the completionTokens values passed to Record
}

// UsageTracker accumulates per-model request and token counts in memory.
//
// Every access to the counters is guarded by mu, so a single tracker can be
// shared between goroutines. Because it contains a mutex it must not be
// copied after first use; always handle it through a pointer.
type UsageTracker struct {
	mu     sync.Mutex
	start  time.Time
	models map[string]*ModelUsage
}

// NewUsageTracker returns an empty tracker whose start time is the current
// time, expressed in UTC.
func NewUsageTracker() *UsageTracker {
	return &UsageTracker{
		start:  time.Now().UTC(),
		models: make(map[string]*ModelUsage),
	}
}

// Record adds one request, plus the given prompt and completion token
// counts, to the totals kept for model.
//
// Record is a no-op on a nil tracker, and it works on a zero-value tracker:
// usage accounting is a side channel of request handling, so a Server built
// without a tracker must not panic while serving a request.
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
	if u == nil {
		return
	}

	u.mu.Lock()
	defer u.mu.Unlock()

	// Allocate lazily so &UsageTracker{} is usable; writing to a nil map
	// would panic.
	if u.models == nil {
		u.models = make(map[string]*ModelUsage)
	}

	m, ok := u.models[model]
	if !ok {
		m = &ModelUsage{}
		u.models[model] = m
	}

	m.Requests++
	m.PromptTokens += int64(promptTokens)
	m.CompletionTokens += int64(completionTokens)
}
|
||||
|
||||
func (u *UsageTracker) Stats() api.UsageResponse {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
byModel := make([]api.ModelUsageData, 0, len(u.models))
|
||||
for model, usage := range u.models {
|
||||
byModel = append(byModel, api.ModelUsageData{
|
||||
Model: model,
|
||||
Requests: usage.Requests,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return api.UsageResponse{
|
||||
Start: u.start,
|
||||
Usage: byModel,
|
||||
}
|
||||
}
|
||||
136
server/usage_test.go
Normal file
136
server/usage_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestUsageTrackerRecord(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
tracker.Record("model-a", 10, 20)
|
||||
tracker.Record("model-a", 5, 15)
|
||||
tracker.Record("model-b", 100, 200)
|
||||
|
||||
stats := tracker.Stats()
|
||||
|
||||
if len(stats.Usage) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
lookup := make(map[string]api.ModelUsageData)
|
||||
for _, m := range stats.Usage {
|
||||
lookup[m.Model] = m
|
||||
}
|
||||
|
||||
a := lookup["model-a"]
|
||||
if a.Requests != 2 {
|
||||
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
|
||||
}
|
||||
if a.PromptTokens != 15 {
|
||||
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
|
||||
}
|
||||
if a.CompletionTokens != 35 {
|
||||
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
|
||||
}
|
||||
|
||||
b := lookup["model-b"]
|
||||
if b.Requests != 1 {
|
||||
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
|
||||
}
|
||||
if b.PromptTokens != 100 {
|
||||
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
|
||||
}
|
||||
if b.CompletionTokens != 200 {
|
||||
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerConcurrent(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.Record("model-a", 1, 2)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if len(stats.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
|
||||
}
|
||||
|
||||
m := stats.Usage[0]
|
||||
if m.Requests != 100 {
|
||||
t.Errorf("requests: expected 100, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 100 {
|
||||
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 200 {
|
||||
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTrackerStart(t *testing.T) {
|
||||
tracker := NewUsageTracker()
|
||||
|
||||
stats := tracker.Stats()
|
||||
if stats.Start.IsZero() {
|
||||
t.Error("expected non-zero start time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
s := &Server{
|
||||
usage: NewUsageTracker(),
|
||||
}
|
||||
|
||||
s.usage.Record("llama3", 50, 100)
|
||||
s.usage.Record("llama3", 25, 50)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
|
||||
|
||||
s.UsageHandler(c)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.UsageResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Usage) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
|
||||
}
|
||||
|
||||
m := resp.Usage[0]
|
||||
if m.Model != "llama3" {
|
||||
t.Errorf("expected model llama3, got %s", m.Model)
|
||||
}
|
||||
if m.Requests != 2 {
|
||||
t.Errorf("expected 2 requests, got %d", m.Requests)
|
||||
}
|
||||
if m.PromptTokens != 75 {
|
||||
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
|
||||
}
|
||||
if m.CompletionTokens != 150 {
|
||||
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,10 @@ import (
|
||||
)
|
||||
|
||||
type Function struct {
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
Name string
|
||||
ReturnType string
|
||||
Params string
|
||||
ParamNames []string
|
||||
NeedsARM64Guard bool
|
||||
}
|
||||
|
||||
@@ -29,11 +29,6 @@ func findHeaders(directory string) ([]string, error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Private headers contain C++ implementation helpers and are not part of
|
||||
// the C API surface; parsing them can produce invalid wrapper signatures.
|
||||
if d.IsDir() && d.Name() == "private" {
|
||||
return fs.SkipDir
|
||||
}
|
||||
if !d.IsDir() && strings.HasSuffix(path, ".h") {
|
||||
headers = append(headers, path)
|
||||
}
|
||||
@@ -199,10 +194,10 @@ func parseFunctions(content string) []Function {
|
||||
needsGuard := needsARM64Guard(funcName, returnType, params)
|
||||
|
||||
functions = append(functions, Function{
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
Name: funcName,
|
||||
ReturnType: returnType,
|
||||
Params: params,
|
||||
ParamNames: paramNames,
|
||||
NeedsARM64Guard: needsGuard,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@ mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
||||
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
||||
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
||||
@@ -51,7 +49,7 @@ int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -69,7 +67,7 @@ const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
||||
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
||||
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
||||
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
||||
#endif
|
||||
@@ -125,7 +123,6 @@ int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_enable_compile_ptr)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
||||
@@ -136,16 +133,6 @@ int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
||||
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
||||
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
||||
@@ -276,6 +263,7 @@ int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
|
||||
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
||||
@@ -670,16 +658,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_ptr = dlsym(handle, "mlx_array_new_data_managed");
|
||||
if (mlx_array_new_data_managed_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_new_data_managed_payload_ptr = dlsym(handle, "mlx_array_new_data_managed_payload");
|
||||
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
|
||||
if (mlx_array_set_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
||||
@@ -1163,11 +1141,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_cuda_is_available_ptr = dlsym(handle, "mlx_cuda_is_available");
|
||||
if (mlx_cuda_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
|
||||
if (mlx_device_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
||||
@@ -1218,56 +1191,6 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_is_available_ptr = dlsym(handle, "mlx_device_is_available");
|
||||
if (mlx_device_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_count_ptr = dlsym(handle, "mlx_device_count");
|
||||
if (mlx_device_count_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_new_ptr = dlsym(handle, "mlx_device_info_new");
|
||||
if (mlx_device_info_new_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_ptr = dlsym(handle, "mlx_device_info_get");
|
||||
if (mlx_device_info_get_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_free_ptr = dlsym(handle, "mlx_device_info_free");
|
||||
if (mlx_device_info_free_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_has_key_ptr = dlsym(handle, "mlx_device_info_has_key");
|
||||
if (mlx_device_info_has_key_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_is_string_ptr = dlsym(handle, "mlx_device_info_is_string");
|
||||
if (mlx_device_info_is_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_string_ptr = dlsym(handle, "mlx_device_info_get_string");
|
||||
if (mlx_device_info_get_string_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_size_ptr = dlsym(handle, "mlx_device_info_get_size");
|
||||
if (mlx_device_info_get_size_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_device_info_get_keys_ptr = dlsym(handle, "mlx_device_info_get_keys");
|
||||
if (mlx_device_info_get_keys_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
|
||||
if (mlx_distributed_all_gather_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
||||
@@ -1918,6 +1841,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
|
||||
if (mlx_metal_device_info_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
|
||||
if (mlx_metal_is_available_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
||||
@@ -3600,14 +3528,6 @@ mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dt
|
||||
return mlx_array_new_data_ptr(data, shape, dim, dtype);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_ptr(data, shape, dim, dtype, dtor);
|
||||
}
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
|
||||
return mlx_array_new_data_managed_payload_ptr(data, shape, dim, dtype, payload, dtor);
|
||||
}
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src) {
|
||||
return mlx_array_set_ptr(arr, src);
|
||||
}
|
||||
@@ -3724,7 +3644,7 @@ int mlx_array_item_float64(double* res, const mlx_array arr) {
|
||||
return mlx_array_item_float64_ptr(res, arr);
|
||||
}
|
||||
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
|
||||
return mlx_array_item_complex64_ptr(res, arr);
|
||||
}
|
||||
|
||||
@@ -3784,7 +3704,7 @@ const double* mlx_array_data_float64(const mlx_array arr) {
|
||||
return mlx_array_data_float64_ptr(arr);
|
||||
}
|
||||
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
|
||||
return mlx_array_data_complex64_ptr(arr);
|
||||
}
|
||||
|
||||
@@ -3996,10 +3916,6 @@ int mlx_set_compile_mode(mlx_compile_mode mode) {
|
||||
return mlx_set_compile_mode_ptr(mode);
|
||||
}
|
||||
|
||||
int mlx_cuda_is_available(bool* res) {
|
||||
return mlx_cuda_is_available_ptr(res);
|
||||
}
|
||||
|
||||
mlx_device mlx_device_new(void) {
|
||||
return mlx_device_new_ptr();
|
||||
}
|
||||
@@ -4040,46 +3956,6 @@ int mlx_set_default_device(mlx_device dev) {
|
||||
return mlx_set_default_device_ptr(dev);
|
||||
}
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
||||
return mlx_device_is_available_ptr(avail, dev);
|
||||
}
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type) {
|
||||
return mlx_device_count_ptr(count, type);
|
||||
}
|
||||
|
||||
mlx_device_info mlx_device_info_new(void) {
|
||||
return mlx_device_info_new_ptr();
|
||||
}
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
||||
return mlx_device_info_get_ptr(info, dev);
|
||||
}
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info) {
|
||||
return mlx_device_info_free_ptr(info);
|
||||
}
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_has_key_ptr(exists, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_is_string_ptr(is_string, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_string_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
||||
return mlx_device_info_get_size_ptr(value, info, key);
|
||||
}
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
||||
return mlx_device_info_get_keys_ptr(keys, info);
|
||||
}
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
||||
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
||||
}
|
||||
@@ -4600,6 +4476,10 @@ int mlx_set_wired_limit(size_t* res, size_t limit) {
|
||||
return mlx_set_wired_limit_ptr(res, limit);
|
||||
}
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void) {
|
||||
return mlx_metal_device_info_ptr();
|
||||
}
|
||||
|
||||
int mlx_metal_is_available(bool* res) {
|
||||
return mlx_metal_is_available_ptr(res);
|
||||
}
|
||||
|
||||
@@ -26,8 +26,6 @@
|
||||
#undef mlx_array_new_double
|
||||
#undef mlx_array_new_complex
|
||||
#undef mlx_array_new_data
|
||||
#undef mlx_array_new_data_managed
|
||||
#undef mlx_array_new_data_managed_payload
|
||||
#undef mlx_array_set
|
||||
#undef mlx_array_set_bool
|
||||
#undef mlx_array_set_int
|
||||
@@ -123,7 +121,6 @@
|
||||
#undef mlx_disable_compile
|
||||
#undef mlx_enable_compile
|
||||
#undef mlx_set_compile_mode
|
||||
#undef mlx_cuda_is_available
|
||||
#undef mlx_device_new
|
||||
#undef mlx_device_new_type
|
||||
#undef mlx_device_free
|
||||
@@ -134,16 +131,6 @@
|
||||
#undef mlx_device_get_type
|
||||
#undef mlx_get_default_device
|
||||
#undef mlx_set_default_device
|
||||
#undef mlx_device_is_available
|
||||
#undef mlx_device_count
|
||||
#undef mlx_device_info_new
|
||||
#undef mlx_device_info_get
|
||||
#undef mlx_device_info_free
|
||||
#undef mlx_device_info_has_key
|
||||
#undef mlx_device_info_is_string
|
||||
#undef mlx_device_info_get_string
|
||||
#undef mlx_device_info_get_size
|
||||
#undef mlx_device_info_get_keys
|
||||
#undef mlx_distributed_all_gather
|
||||
#undef mlx_distributed_all_max
|
||||
#undef mlx_distributed_all_min
|
||||
@@ -274,6 +261,7 @@
|
||||
#undef mlx_set_cache_limit
|
||||
#undef mlx_set_memory_limit
|
||||
#undef mlx_set_wired_limit
|
||||
#undef mlx_metal_device_info
|
||||
#undef mlx_metal_is_available
|
||||
#undef mlx_metal_start_capture
|
||||
#undef mlx_metal_stop_capture
|
||||
@@ -614,8 +602,6 @@ extern mlx_array (*mlx_array_new_float64_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_double_ptr)(double val);
|
||||
extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val);
|
||||
extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
extern mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
extern mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src);
|
||||
extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val);
|
||||
extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val);
|
||||
@@ -645,7 +631,7 @@ extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr);
|
||||
extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
@@ -663,7 +649,7 @@ extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr);
|
||||
extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr);
|
||||
extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr);
|
||||
extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr);
|
||||
extern const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr);
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr);
|
||||
#endif
|
||||
@@ -719,7 +705,6 @@ extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id);
|
||||
extern int (*mlx_disable_compile_ptr)(void);
|
||||
extern int (*mlx_enable_compile_ptr)(void);
|
||||
extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode);
|
||||
extern int (*mlx_cuda_is_available_ptr)(bool* res);
|
||||
extern mlx_device (*mlx_device_new_ptr)(void);
|
||||
extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index);
|
||||
extern int (*mlx_device_free_ptr)(mlx_device dev);
|
||||
@@ -730,16 +715,6 @@ extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev);
|
||||
extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev);
|
||||
extern int (*mlx_get_default_device_ptr)(mlx_device* dev);
|
||||
extern int (*mlx_set_default_device_ptr)(mlx_device dev);
|
||||
extern int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev);
|
||||
extern int (*mlx_device_count_ptr)(int* count, mlx_device_type type);
|
||||
extern mlx_device_info (*mlx_device_info_new_ptr)(void);
|
||||
extern int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev);
|
||||
extern int (*mlx_device_info_free_ptr)(mlx_device_info info);
|
||||
extern int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key);
|
||||
extern int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info);
|
||||
extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -870,6 +845,7 @@ extern int (*mlx_reset_peak_memory_ptr)(void);
|
||||
extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit);
|
||||
extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit);
|
||||
extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void);
|
||||
extern int (*mlx_metal_is_available_ptr)(bool* res);
|
||||
extern int (*mlx_metal_start_capture_ptr)(const char* path);
|
||||
extern int (*mlx_metal_stop_capture_ptr)(void);
|
||||
@@ -1226,10 +1202,6 @@ mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
|
||||
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype);
|
||||
|
||||
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*));
|
||||
|
||||
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*));
|
||||
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
@@ -1288,7 +1260,7 @@ int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
@@ -1320,7 +1292,7 @@ const float* mlx_array_data_float32(const mlx_array arr);
|
||||
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#if defined(__aarch64__) || defined(_M_ARM64)
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
@@ -1428,8 +1400,6 @@ int mlx_enable_compile(void);
|
||||
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
@@ -1450,26 +1420,6 @@ int mlx_get_default_device(mlx_device* dev);
|
||||
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
|
||||
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key);
|
||||
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S);
|
||||
|
||||
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s);
|
||||
@@ -1730,6 +1680,8 @@ int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
mlx_metal_device_info_t mlx_metal_device_info(void);
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
|
||||
@@ -15,7 +15,7 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -22,19 +22,6 @@ mlx_array (*mlx_array_new_data_)(
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
mlx_array (*mlx_array_new_data_managed_payload_)(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*)) = NULL;
|
||||
int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
|
||||
int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
|
||||
int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
|
||||
@@ -69,7 +56,7 @@ int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
|
||||
int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
|
||||
const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
|
||||
@@ -83,7 +70,7 @@ const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
|
||||
const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
|
||||
const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
|
||||
const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
|
||||
const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
|
||||
const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
|
||||
const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
|
||||
int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
|
||||
@@ -107,11 +94,10 @@ int (*mlx_closure_apply_)(
|
||||
mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
|
||||
int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array)) = NULL;
|
||||
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -150,12 +136,11 @@ int (*mlx_closure_value_and_grad_apply_)(
|
||||
const mlx_vector_array input) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array)) = NULL;
|
||||
mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -176,13 +161,12 @@ int (*mlx_closure_custom_apply_)(
|
||||
const mlx_vector_array input_2) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -205,13 +189,12 @@ int (*mlx_closure_custom_jvp_apply_)(
|
||||
size_t input_2_num) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
|
||||
int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num)) = NULL;
|
||||
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
@@ -245,7 +228,6 @@ int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
|
||||
int (*mlx_disable_compile_)(void) = NULL;
|
||||
int (*mlx_enable_compile_)(void) = NULL;
|
||||
int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
|
||||
int (*mlx_cuda_is_available_)(bool* res) = NULL;
|
||||
mlx_device (*mlx_device_new_)(void) = NULL;
|
||||
mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
|
||||
int (*mlx_device_free_)(mlx_device dev) = NULL;
|
||||
@@ -256,28 +238,11 @@ int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
|
||||
int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
|
||||
int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
|
||||
int (*mlx_set_default_device_)(mlx_device dev) = NULL;
|
||||
int (*mlx_device_is_available_)(bool* avail, mlx_device dev) = NULL;
|
||||
int (*mlx_device_count_)(int* count, mlx_device_type type) = NULL;
|
||||
mlx_device_info (*mlx_device_info_new_)(void) = NULL;
|
||||
int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev) = NULL;
|
||||
int (*mlx_device_info_free_)(mlx_device_info info) = NULL;
|
||||
int (*mlx_device_info_has_key_)(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_is_string_)(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_string_)(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_size_)(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key) = NULL;
|
||||
int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
int (*mlx_distributed_all_gather_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
@@ -323,11 +288,6 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -490,16 +450,6 @@ int (*mlx_fast_rope_)(
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_rope_dynamic_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_fast_scaled_dot_product_attention_)(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
@@ -610,6 +560,14 @@ int (*mlx_fft_rfftn_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_load_reader_)(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
@@ -635,14 +593,6 @@ int (*mlx_save_safetensors_)(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata) = NULL;
|
||||
mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
|
||||
int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
|
||||
mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
|
||||
int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
|
||||
int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
|
||||
int (*mlx_linalg_cholesky_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -783,6 +733,7 @@ int (*mlx_reset_peak_memory_)(void) = NULL;
|
||||
int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
|
||||
int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
|
||||
mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
|
||||
int (*mlx_metal_is_available_)(bool* res) = NULL;
|
||||
int (*mlx_metal_start_capture_)(const char* path) = NULL;
|
||||
int (*mlx_metal_stop_capture_)(void) = NULL;
|
||||
@@ -1211,14 +1162,6 @@ int (*mlx_gather_)(
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
int axis,
|
||||
const int* slice_sizes,
|
||||
size_t slice_sizes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_gather_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1540,15 +1483,6 @@ int (*mlx_put_along_axis_)(
|
||||
const mlx_array values,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array w,
|
||||
const mlx_array w_scales /* may be null */,
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array w,
|
||||
@@ -1632,13 +1566,6 @@ int (*mlx_scatter_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1647,13 +1574,6 @@ int (*mlx_scatter_add_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_add_axis_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1669,13 +1589,6 @@ int (*mlx_scatter_max_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_max_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1684,13 +1597,6 @@ int (*mlx_scatter_min_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_min_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1699,13 +1605,6 @@ int (*mlx_scatter_prod_)(
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_scatter_prod_single_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array indices,
|
||||
const mlx_array updates,
|
||||
int axis,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_segmented_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2129,6 +2028,22 @@ mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
|
||||
int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
|
||||
const char * (*mlx_string_data_)(mlx_string str) = NULL;
|
||||
int (*mlx_string_free_)(mlx_string str) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
|
||||
int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
|
||||
int (*mlx_custom_function_)(
|
||||
@@ -2159,22 +2074,6 @@ int (*mlx_vjp_)(
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents) = NULL;
|
||||
int (*mlx_detail_vmap_replace_)(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num) = NULL;
|
||||
int (*mlx_detail_vmap_trace_)(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num) = NULL;
|
||||
mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
|
||||
int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
||||
int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
|
||||
@@ -2267,8 +2166,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_array_new_double);
|
||||
CHECK_LOAD(handle, mlx_array_new_complex);
|
||||
CHECK_LOAD(handle, mlx_array_new_data);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed);
|
||||
CHECK_LOAD(handle, mlx_array_new_data_managed_payload);
|
||||
CHECK_LOAD(handle, mlx_array_set);
|
||||
CHECK_LOAD(handle, mlx_array_set_bool);
|
||||
CHECK_LOAD(handle, mlx_array_set_int);
|
||||
@@ -2364,7 +2261,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_disable_compile);
|
||||
CHECK_LOAD(handle, mlx_enable_compile);
|
||||
CHECK_LOAD(handle, mlx_set_compile_mode);
|
||||
CHECK_LOAD(handle, mlx_cuda_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_new);
|
||||
CHECK_LOAD(handle, mlx_device_new_type);
|
||||
CHECK_LOAD(handle, mlx_device_free);
|
||||
@@ -2375,16 +2271,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_device_get_type);
|
||||
CHECK_LOAD(handle, mlx_get_default_device);
|
||||
CHECK_LOAD(handle, mlx_set_default_device);
|
||||
CHECK_LOAD(handle, mlx_device_is_available);
|
||||
CHECK_LOAD(handle, mlx_device_count);
|
||||
CHECK_LOAD(handle, mlx_device_info_new);
|
||||
CHECK_LOAD(handle, mlx_device_info_get);
|
||||
CHECK_LOAD(handle, mlx_device_info_free);
|
||||
CHECK_LOAD(handle, mlx_device_info_has_key);
|
||||
CHECK_LOAD(handle, mlx_device_info_is_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_string);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_size);
|
||||
CHECK_LOAD(handle, mlx_device_info_get_keys);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_gather);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_max);
|
||||
CHECK_LOAD(handle, mlx_distributed_all_min);
|
||||
@@ -2393,11 +2284,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_distributed_recv_like);
|
||||
CHECK_LOAD(handle, mlx_distributed_send);
|
||||
CHECK_LOAD(handle, mlx_distributed_sum_scatter);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_rank);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_size);
|
||||
CHECK_LOAD(handle, mlx_distributed_group_split);
|
||||
CHECK_LOAD(handle, mlx_distributed_is_available);
|
||||
CHECK_LOAD(handle, mlx_distributed_init);
|
||||
CHECK_LOAD(handle, mlx_set_error_handler);
|
||||
CHECK_LOAD(handle, _mlx_error);
|
||||
CHECK_LOAD(handle, mlx_export_function);
|
||||
@@ -2439,7 +2325,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
|
||||
CHECK_LOAD(handle, mlx_fast_rms_norm);
|
||||
CHECK_LOAD(handle, mlx_fast_rope);
|
||||
CHECK_LOAD(handle, mlx_fast_rope_dynamic);
|
||||
CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
|
||||
CHECK_LOAD(handle, mlx_fft_fft);
|
||||
CHECK_LOAD(handle, mlx_fft_fft2);
|
||||
@@ -2455,14 +2340,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_fft_rfft);
|
||||
CHECK_LOAD(handle, mlx_fft_rfft2);
|
||||
CHECK_LOAD(handle, mlx_fft_rfftn);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors);
|
||||
CHECK_LOAD(handle, mlx_save_writer);
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_io_reader_new);
|
||||
CHECK_LOAD(handle, mlx_io_reader_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_reader_tostring);
|
||||
@@ -2471,6 +2348,14 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_io_writer_descriptor);
|
||||
CHECK_LOAD(handle, mlx_io_writer_tostring);
|
||||
CHECK_LOAD(handle, mlx_io_writer_free);
|
||||
CHECK_LOAD(handle, mlx_load_reader);
|
||||
CHECK_LOAD(handle, mlx_load);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors_reader);
|
||||
CHECK_LOAD(handle, mlx_load_safetensors);
|
||||
CHECK_LOAD(handle, mlx_save_writer);
|
||||
CHECK_LOAD(handle, mlx_save);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors_writer);
|
||||
CHECK_LOAD(handle, mlx_save_safetensors);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky);
|
||||
CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
|
||||
CHECK_LOAD(handle, mlx_linalg_cross);
|
||||
@@ -2515,6 +2400,7 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_set_cache_limit);
|
||||
CHECK_LOAD(handle, mlx_set_memory_limit);
|
||||
CHECK_LOAD(handle, mlx_set_wired_limit);
|
||||
CHECK_LOAD(handle, mlx_metal_device_info);
|
||||
CHECK_LOAD(handle, mlx_metal_is_available);
|
||||
CHECK_LOAD(handle, mlx_metal_start_capture);
|
||||
CHECK_LOAD(handle, mlx_metal_stop_capture);
|
||||
@@ -2600,7 +2486,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_full);
|
||||
CHECK_LOAD(handle, mlx_full_like);
|
||||
CHECK_LOAD(handle, mlx_gather);
|
||||
CHECK_LOAD(handle, mlx_gather_single);
|
||||
CHECK_LOAD(handle, mlx_gather_mm);
|
||||
CHECK_LOAD(handle, mlx_gather_qmm);
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
@@ -2665,7 +2550,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_prod_axis);
|
||||
CHECK_LOAD(handle, mlx_prod);
|
||||
CHECK_LOAD(handle, mlx_put_along_axis);
|
||||
CHECK_LOAD(handle, mlx_qqmm);
|
||||
CHECK_LOAD(handle, mlx_quantize);
|
||||
CHECK_LOAD(handle, mlx_quantized_matmul);
|
||||
CHECK_LOAD(handle, mlx_radians);
|
||||
@@ -2682,16 +2566,11 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_round);
|
||||
CHECK_LOAD(handle, mlx_rsqrt);
|
||||
CHECK_LOAD(handle, mlx_scatter);
|
||||
CHECK_LOAD(handle, mlx_scatter_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_add_axis);
|
||||
CHECK_LOAD(handle, mlx_scatter_max);
|
||||
CHECK_LOAD(handle, mlx_scatter_max_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_min);
|
||||
CHECK_LOAD(handle, mlx_scatter_min_single);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod);
|
||||
CHECK_LOAD(handle, mlx_scatter_prod_single);
|
||||
CHECK_LOAD(handle, mlx_segmented_mm);
|
||||
CHECK_LOAD(handle, mlx_sigmoid);
|
||||
CHECK_LOAD(handle, mlx_sign);
|
||||
@@ -2786,6 +2665,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_string_set);
|
||||
CHECK_LOAD(handle, mlx_string_data);
|
||||
CHECK_LOAD(handle, mlx_string_free);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_async_eval);
|
||||
CHECK_LOAD(handle, mlx_checkpoint);
|
||||
CHECK_LOAD(handle, mlx_custom_function);
|
||||
@@ -2794,8 +2675,6 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_jvp);
|
||||
CHECK_LOAD(handle, mlx_value_and_grad);
|
||||
CHECK_LOAD(handle, mlx_vjp);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_replace);
|
||||
CHECK_LOAD(handle, mlx_detail_vmap_trace);
|
||||
CHECK_LOAD(handle, mlx_vector_array_new);
|
||||
CHECK_LOAD(handle, mlx_vector_array_set);
|
||||
CHECK_LOAD(handle, mlx_vector_array_free);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,10 +4,6 @@
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
|
||||
Reference in New Issue
Block a user